Jasmine402 committed on
Commit
e0fc0a8
·
verified ·
1 Parent(s): f18b1fb

Upload folder using huggingface_hub

Browse files
AnchorIT_CIR_app.py ADDED
The diff for this file is too large to render. See raw diff
 
README.md CHANGED
@@ -1,12 +1,6 @@
  ---
- title: AnchorIT ZS-CIR BNU
- emoji: 🐨
- colorFrom: blue
- colorTo: purple
+ title: AnchorIT_ZS-CIR_BNU
+ app_file: AnchorIT_CIR_app.py
  sdk: gradio
- sdk_version: 5.29.0
- app_file: app.py
- pinned: false
+ sdk_version: 4.44.1
  ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference

__pycache__/aspect_ratio_template.cpython-310.pyc ADDED
Binary file (1.16 kB). View file
 
__pycache__/base_model_template.cpython-310.pyc ADDED
Binary file (1.23 kB). View file
 
__pycache__/brushedit_all_in_one_pipeline.cpython-310.pyc ADDED
Binary file (1.73 kB). View file
 
__pycache__/llm_pipeline.cpython-310.pyc ADDED
Binary file (1.77 kB). View file
 
__pycache__/llm_template.cpython-310.pyc ADDED
Binary file (562 Bytes). View file
 
__pycache__/vlm_pipeline.cpython-310.pyc ADDED
Binary file (6.85 kB). View file
 
__pycache__/vlm_pipeline_noqwen.cpython-310.pyc ADDED
Binary file (6.83 kB). View file
 
__pycache__/vlm_template.cpython-310.pyc ADDED
Binary file (1.3 kB). View file
 
aspect_ratio_template.py ADDED
@@ -0,0 +1,88 @@
+ # From https://github.com/TencentARC/PhotoMaker/pull/120 written by https://github.com/DiscoNova
+ # Note: Since output width & height need to be divisible by 8, the w & h -values do
+ # not exactly match the stated aspect ratios... but they are "close enough":)
+
+ aspect_ratio_list = [
+     {
+         "name": "Small Square (1:1)",
+         "w": 640,
+         "h": 640,
+     },
+     {
+         "name": "Custom resolution",
+         "w": "",
+         "h": "",
+     },
+     {
+         "name": "Instagram (1:1)",
+         "w": 1024,
+         "h": 1024,
+     },
+     {
+         "name": "35mm film / Landscape (3:2)",
+         "w": 1024,
+         "h": 680,
+     },
+     {
+         "name": "35mm film / Portrait (2:3)",
+         "w": 680,
+         "h": 1024,
+     },
+     {
+         "name": "CRT Monitor / Landscape (4:3)",
+         "w": 1024,
+         "h": 768,
+     },
+     {
+         "name": "CRT Monitor / Portrait (3:4)",
+         "w": 768,
+         "h": 1024,
+     },
+     {
+         "name": "Widescreen TV / Landscape (16:9)",
+         "w": 1024,
+         "h": 576,
+     },
+     {
+         "name": "Widescreen TV / Portrait (9:16)",
+         "w": 576,
+         "h": 1024,
+     },
+     {
+         "name": "Widescreen Monitor / Landscape (16:10)",
+         "w": 1024,
+         "h": 640,
+     },
+     {
+         "name": "Widescreen Monitor / Portrait (10:16)",
+         "w": 640,
+         "h": 1024,
+     },
+     {
+         "name": "Cinemascope (2.39:1)",
+         "w": 1024,
+         "h": 424,
+     },
+     {
+         "name": "Widescreen Movie (1.85:1)",
+         "w": 1024,
+         "h": 552,
+     },
+     {
+         "name": "Academy Movie (1.37:1)",
+         "w": 1024,
+         "h": 744,
+     },
+     {
+         "name": "Sheet-print (A-series) / Landscape (297:210)",
+         "w": 1024,
+         "h": 720,
+     },
+     {
+         "name": "Sheet-print (A-series) / Portrait (210:297)",
+         "w": 720,
+         "h": 1024,
+     },
+ ]
+
+ aspect_ratios = {k["name"]: (k["w"], k["h"]) for k in aspect_ratio_list}
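aspect_ratios maps each preset name to a (w, h) pair whose sides are rounded to multiples of 8, as the note at the top of the file explains. A minimal usage sketch (the helper below is illustrative, not part of this commit):

```python
# Hypothetical helper (not in this commit): resolve a preset name to (w, h),
# falling back to a caller-supplied size for the "Custom resolution" entry.
from aspect_ratio_template import aspect_ratios

def resolve_size(preset_name, fallback_wh=(1024, 1024)):
    w, h = aspect_ratios[preset_name]
    if w == "" or h == "":               # "Custom resolution" stores empty strings
        return fallback_wh
    assert w % 8 == 0 and h % 8 == 0     # every preset is already divisible by 8
    return w, h

print(resolve_size("Widescreen TV / Landscape (16:9)"))  # -> (1024, 576)
```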
base_model_template.py ADDED
@@ -0,0 +1,61 @@
+ import os
+ import torch
+ from huggingface_hub import snapshot_download
+
+ from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler
+
+
+ torch_dtype = torch.float16
+ device = "cpu"
+
+ BrushEdit_path = "models/"
+ if not os.path.exists(BrushEdit_path):
+     BrushEdit_path = snapshot_download(
+         repo_id="TencentARC/BrushEdit",
+         local_dir=BrushEdit_path,
+         token=os.getenv("HF_TOKEN"),
+     )
+ brushnet_path = os.path.join(BrushEdit_path, "brushnetX")
+ brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch_dtype)
+
+
+ base_models_list = [
+     # {
+     #     "name": "dreamshaper_8 (Preload)",
+     #     "local_path": "models/base_model/dreamshaper_8",
+     #     "pipe": StableDiffusionBrushNetPipeline.from_pretrained(
+     #         "models/base_model/dreamshaper_8", brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
+     #     ).to(device)
+     # },
+     # {
+     #     "name": "epicrealism (Preload)",
+     #     "local_path": "models/base_model/epicrealism_naturalSinRC1VAE",
+     #     "pipe": StableDiffusionBrushNetPipeline.from_pretrained(
+     #         "models/base_model/epicrealism_naturalSinRC1VAE", brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
+     #     ).to(device)
+     # },
+     {
+         "name": "henmixReal (Preload)",
+         "local_path": "models/base_model/henmixReal_v5c",
+         "pipe": StableDiffusionBrushNetPipeline.from_pretrained(
+             "models/base_model/henmixReal_v5c", brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
+         ).to(device)
+     },
+     {
+         "name": "meinamix (Preload)",
+         "local_path": "models/base_model/meinamix_meinaV11",
+         "pipe": StableDiffusionBrushNetPipeline.from_pretrained(
+             "models/base_model/meinamix_meinaV11", brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
+         ).to(device)
+     },
+     {
+         "name": "realisticVision (Default)",
+         "local_path": "models/base_model/realisticVisionV60B1_v51VAE",
+         "pipe": StableDiffusionBrushNetPipeline.from_pretrained(
+             "models/base_model/realisticVisionV60B1_v51VAE", brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
+         ).to(device)
+     },
+ ]
+
+ base_models_template = {k["name"]: (k["local_path"], k["pipe"]) for k in base_models_list}
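base_models_template maps each display name to a (local_path, pipe) tuple; the commented-out entries are not preloaded. A hedged sketch of how app-side code might pick an entry (the device choice and control flow below are assumptions for illustration):

```python
# Usage sketch (assumed, not part of this commit): look up the default base model
# and move its preloaded pipeline to the target device before inference.
from base_model_template import base_models_template

local_path, pipe = base_models_template["realisticVision (Default)"]
if pipe != "":
    pipe = pipe.to("cuda")   # preloaded pipeline: just relocate it
else:
    # entry was not preloaded; the app would rebuild it from local_path instead
    print(f"build StableDiffusionBrushNetPipeline from {local_path}")
```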
brushedit_all_in_one_pipeline.py ADDED
@@ -0,0 +1,73 @@
+ from PIL import Image, ImageEnhance
+ from diffusers.image_processor import VaeImageProcessor
+
+ import numpy as np
+ import cv2
+
+
+ def BrushEdit_Pipeline(pipe,
+                        prompts,
+                        mask_np,
+                        original_image,
+                        generator,
+                        num_inference_steps,
+                        guidance_scale,
+                        control_strength,
+                        negative_prompt,
+                        num_samples,
+                        blending):
+     if mask_np.ndim != 3:
+         mask_np = mask_np[:, :, np.newaxis]
+
+     mask_np = mask_np / 255
+     height, width = mask_np.shape[0], mask_np.shape[1]
+     ## resize the mask and original image to the same size which is divisible by vae_scale_factor
+     image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
+     height_new, width_new = image_processor.get_default_height_width(original_image, height, width)
+     mask_np = cv2.resize(mask_np, (width_new, height_new))[:,:,np.newaxis]
+     mask_blurred = cv2.GaussianBlur(mask_np*255, (21, 21), 0)/255
+     mask_blurred = mask_blurred[:, :, np.newaxis]
+
+     original_image = cv2.resize(original_image, (width_new, height_new))
+
+     init_image = original_image * (1 - mask_np)
+     init_image = Image.fromarray(init_image.astype(np.uint8)).convert("RGB")
+     mask_image = Image.fromarray((mask_np.repeat(3, -1) * 255).astype(np.uint8)).convert("RGB")
+
+     brushnet_conditioning_scale = float(control_strength)
+
+     images = pipe(
+         [prompts] * num_samples,
+         init_image,
+         mask_image,
+         num_inference_steps=num_inference_steps,
+         guidance_scale=guidance_scale,
+         generator=generator,
+         brushnet_conditioning_scale=brushnet_conditioning_scale,
+         negative_prompt=[negative_prompt]*num_samples,
+         height=height_new,
+         width=width_new,
+     ).images
+     ## convert to vae shape format, must be divisible by 8
+     original_image_pil = Image.fromarray(original_image).convert("RGB")
+     init_image_np = np.array(image_processor.preprocess(original_image_pil, height=height_new, width=width_new).squeeze())
+     init_image_np = ((init_image_np.transpose(1,2,0) + 1.) / 2.) * 255
+     init_image_np = init_image_np.astype(np.uint8)
+     if blending:
+         mask_blurred = mask_blurred * 0.5 + 0.5
+         image_all = []
+         for image_i in images:
+             image_np = np.array(image_i)
+             ## blending
+             image_pasted = init_image_np * (1 - mask_blurred) + mask_blurred * image_np
+             image_pasted = image_pasted.astype(np.uint8)
+             image = Image.fromarray(image_pasted)
+             image_all.append(image)
+     else:
+         image_all = images
+
+     return image_all, mask_image, mask_np, init_image_np
+
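BrushEdit_Pipeline performs one inpainting pass: it normalizes and blurs the mask, blanks the masked region of the input, calls the BrushNet pipeline, and optionally blends each sample back into the original image. A hedged call sketch, assuming a pipe built as in brushedit_app.py and uint8 numpy inputs (all other names are illustrative):

```python
# Call sketch (assumed): image_np is an RGB uint8 array, mask_np a single-channel
# uint8 mask where 255 marks the region to edit; pipe is a StableDiffusionBrushNetPipeline.
import torch
from brushedit_all_in_one_pipeline import BrushEdit_Pipeline

generator = torch.Generator("cuda").manual_seed(648464818)
edited, mask_image, mask_out, init_np = BrushEdit_Pipeline(
    pipe,
    "a magic hat on the frog's head",   # target prompt describing the masked area
    mask_np,
    image_np,
    generator,
    num_inference_steps=50,
    guidance_scale=7.5,
    control_strength=1.0,
    negative_prompt="ugly, low quality",
    num_samples=4,
    blending=True,
)
edited[0].save("image_edit_0.png")
```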
brushedit_app.py ADDED
@@ -0,0 +1,1705 @@
+ ##!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ import os, random, sys
+ import numpy as np
+ import requests
+ import torch
+
+
+ import gradio as gr
+
+ from PIL import Image
+
+
+ from huggingface_hub import hf_hub_download, snapshot_download
+ from scipy.ndimage import binary_dilation, binary_erosion
+ from transformers import (LlavaNextProcessor, LlavaNextForConditionalGeneration,
+                           Qwen2VLForConditionalGeneration, Qwen2VLProcessor)
+
+ from segment_anything import SamPredictor, build_sam, SamAutomaticMaskGenerator
+ from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler
+ from diffusers.image_processor import VaeImageProcessor
+
+
+ from app.src.vlm_pipeline import (
+     vlm_response_editing_type,
+     vlm_response_object_wait_for_edit,
+     vlm_response_mask,
+     vlm_response_prompt_after_apply_instruction
+ )
+ from app.src.brushedit_all_in_one_pipeline import BrushEdit_Pipeline
+ from app.utils.utils import load_grounding_dino_model
+
+ from app.src.vlm_template import vlms_template
+ from app.src.base_model_template import base_models_template
+ from app.src.aspect_ratio_template import aspect_ratios
+
+ from openai import OpenAI
+ # base_openai_url = "https://api.deepseek.com/"
+
+ #### Description ####
+ logo = r"""
+ <center><img src='./assets/logo_brushedit.png' alt='BrushEdit logo' style="width:80px; margin-bottom:10px"></center>
+ """
+ head = r"""
+ <div style="text-align: center;">
+ <h1> BrushEdit: All-In-One Image Inpainting and Editing</h1>
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
+ <a href='https://liyaowei-stu.github.io/project/BrushEdit/'><img src='https://img.shields.io/badge/Project_Page-BrushEdit-green' alt='Project Page'></a>
+ <a href='https://arxiv.org/abs/2412.10316'><img src='https://img.shields.io/badge/Paper-Arxiv-blue'></a>
+ <a href='https://github.com/TencentARC/BrushEdit'><img src='https://img.shields.io/badge/Code-Github-orange'></a>
+
+ </div>
+ </br>
+ </div>
+ """
+ descriptions = r"""
+ Official Gradio Demo for <a href='https://tencentarc.github.io/BrushNet/'><b>BrushEdit: All-In-One Image Inpainting and Editing</b></a><br>
+ 🧙 BrushEdit enables precise, user-friendly instruction-based image editing via an inpainting model.<br>
+ """
+
+ instructions = r"""
+ Currently, we support two modes: <b>fully automated instruction-based editing</b> and <b>interactive instruction-based editing</b>.
+
+ 🛠️ <b>Fully automated instruction-based editing</b>:
+ <ul>
+ <li> ⭐️ <b>1. Choose Image: </b> Upload <img src="https://github.com/user-attachments/assets/f2dca1e6-31f9-4716-ae84-907f24415bac" alt="upload" style="display:inline; height:1em; vertical-align:middle;"> an image or select <img src="https://github.com/user-attachments/assets/de808f7d-c74a-44c7-9cbf-f0dbfc2c1abf" alt="example" style="display:inline; height:1em; vertical-align:middle;"> one from the Examples. </li>
+ <li> ⭐️ <b>2. Input ⌨️ Instructions: </b> Enter the instructions (addition, deletion, and modification are supported), e.g. "remove the xxx". </li>
+ <li> ⭐️ <b>3. Run: </b> Click the <b>💫 Run</b> button to edit the image automatically.</li>
+ </ul>
+
+ 🛠️ <b>Interactive instruction-based editing</b>:
+ <ul>
+ <li> ⭐️ <b>1. Choose Image: </b> Upload <img src="https://github.com/user-attachments/assets/f2dca1e6-31f9-4716-ae84-907f24415bac" alt="upload" style="display:inline; height:1em; vertical-align:middle;"> an image or select <img src="https://github.com/user-attachments/assets/de808f7d-c74a-44c7-9cbf-f0dbfc2c1abf" alt="example" style="display:inline; height:1em; vertical-align:middle;"> one from the Examples. </li>
+ <li> ⭐️ <b>2. Brush Finely: </b> Use the brush <img src="https://github.com/user-attachments/assets/c466c5cc-ac8f-4b4a-9bc5-04c4737fe1ef" alt="brush" style="display:inline; height:1em; vertical-align:middle;"> to outline the area you want to edit. You can also use the eraser <img src="https://github.com/user-attachments/assets/b6370369-b080-4550-b0d0-830ff22d9068" alt="eraser" style="display:inline; height:1em; vertical-align:middle;"> to restore it. </li>
+ <li> ⭐️ <b>3. Input ⌨️ Instructions: </b> Enter the instructions. </li>
+ <li> ⭐️ <b>4. Run: </b> Click the <b>💫 Run</b> button to edit the image automatically. </li>
+ </ul>
+
+ <b> We strongly recommend using GPT-4o for reasoning. </b> After selecting GPT-4o as the VLM model, enter the API key and click the Submit and Verify button. If the output is "success", GPT-4o is ready to use. As a second choice, we recommend the Qwen2-VL model.
+
+ <b> We recommend zooming out in your browser for a better viewing range and experience. </b>
+
+ <b> For more detailed feature descriptions, see the bottom. </b>
+
+ ☕️ Have fun! 🎄 Wishing you a merry Christmas!
+ """
+
+ tips = r"""
+ 💡 <b>Some Tips</b>:
+ <ul>
+ <li> 🤠 After entering the instructions, you can click the <b>Generate Mask</b> button. The mask generated by the VLM will be displayed in the preview panel on the right side. </li>
+ <li> 🤠 After generating the mask, or after drawing a mask with the brush, you can apply operations such as <b>randomization</b>, <b>dilation</b>, <b>erosion</b>, and <b>movement</b>. </li>
+ <li> 🤠 After entering the instructions, you can click the <b>Generate Target Prompt</b> button. The target prompt will be displayed in the text box, and you can modify it to match your ideas. </li>
+ </ul>
+
+ 💡 <b>Detailed Features</b>:
+ <ul>
+ <li> 🎨 <b>Aspect Ratio</b>: Select the aspect ratio of the image. To prevent OOM, 1024px is the maximum resolution. </li>
+ <li> 🎨 <b>VLM Model</b>: Select the VLM model. We use preloaded models to save time. To use other VLM models, download them and uncomment the relevant lines in vlm_template.py from our GitHub repo. </li>
+ <li> 🎨 <b>Generate Mask</b>: Generate a mask for the area that may need to be edited, according to the input instructions. </li>
+ <li> 🎨 <b>Square/Circle Mask</b>: Based on the existing mask, generate square and circular masks. (The coarse-grained mask leaves more room for editing.) </li>
+ <li> 🎨 <b>Invert Mask</b>: Invert the mask to generate a new mask. </li>
+ <li> 🎨 <b>Dilation/Erosion Mask</b>: Expand or shrink the mask to include or exclude more area. </li>
+ <li> 🎨 <b>Move Mask</b>: Move the mask to a new position. </li>
+ <li> 🎨 <b>Generate Target Prompt</b>: Generate a target prompt based on the input instructions. </li>
+ <li> 🎨 <b>Target Prompt</b>: Description of the masked area; enter or modify it manually when the content generated by the VLM does not meet expectations. </li>
+ <li> 🎨 <b>Blending</b>: Blend BrushNet's output with the original input, preserving the original image details in the unedited areas. (Turning it off is better when removing objects.) </li>
+ <li> 🎨 <b>Control length</b>: The intensity of editing and inpainting. </li>
+ </ul>
+
+ 💡 <b>Advanced Features</b>:
+ <ul>
+ <li> 🎨 <b>Base Model</b>: We use preloaded models to save time. To use other base models, download them and uncomment the relevant lines in base_model_template.py from our GitHub repo. </li>
+ <li> 🎨 <b>Blending</b>: Blend BrushNet's output with the original input, preserving the original image details in the unedited areas. (Turning it off is better when removing objects.) </li>
+ <li> 🎨 <b>Control length</b>: The intensity of editing and inpainting. </li>
+ <li> 🎨 <b>Num samples</b>: The number of samples to generate. </li>
+ <li> 🎨 <b>Negative prompt</b>: The negative prompt for classifier-free guidance. </li>
+ <li> 🎨 <b>Guidance scale</b>: The guidance scale for classifier-free guidance. </li>
+ </ul>
+
+ """
+
+
+ citation = r"""
+ If BrushEdit is helpful, please help to ⭐ the <a href='https://github.com/TencentARC/BrushEdit' target='_blank'>GitHub Repo</a>. Thanks!
+ [![GitHub Stars](https://img.shields.io/github/stars/TencentARC/BrushEdit?style=social)](https://github.com/TencentARC/BrushEdit)
+ ---
+ 📝 **Citation**
+ <br>
+ If our work is useful for your research, please consider citing:
+ ```bibtex
+ @misc{li2024brushedit,
+   title={BrushEdit: All-In-One Image Inpainting and Editing},
+   author={Yaowei Li and Yuxuan Bian and Xuan Ju and Zhaoyang Zhang and Junhao Zhuang and Ying Shan and Yuexian Zou and Qiang Xu},
+   year={2024},
+   eprint={2412.10316},
+   archivePrefix={arXiv},
+   primaryClass={cs.CV}
+ }
+ ```
+ 📧 **Contact**
+ <br>
+ If you have any questions, please feel free to reach out to me at <b>liyaowei@gmail.com</b>.
+ """
148
+ # - - - - - examples - - - - - #
149
+ EXAMPLES = [
150
+
151
+ [
152
+ Image.open("./assets/frog/frog.jpeg").convert("RGBA"),
153
+ "add a magic hat on frog head.",
154
+ 642087011,
155
+ "frog",
156
+ "frog",
157
+ True,
158
+ False,
159
+ "GPT4-o (Highly Recommended)"
160
+ ],
161
+ [
162
+ Image.open("./assets/chinese_girl/chinese_girl.png").convert("RGBA"),
163
+ "replace the background to ancient China.",
164
+ 648464818,
165
+ "chinese_girl",
166
+ "chinese_girl",
167
+ True,
168
+ False,
169
+ "GPT4-o (Highly Recommended)"
170
+ ],
171
+ [
172
+ Image.open("./assets/angel_christmas/angel_christmas.png").convert("RGBA"),
173
+ "remove the deer.",
174
+ 648464818,
175
+ "angel_christmas",
176
+ "angel_christmas",
177
+ False,
178
+ False,
179
+ "GPT4-o (Highly Recommended)"
180
+ ],
181
+ [
182
+ Image.open("./assets/sunflower_girl/sunflower_girl.png").convert("RGBA"),
183
+ "add a wreath on head.",
184
+ 648464818,
185
+ "sunflower_girl",
186
+ "sunflower_girl",
187
+ True,
188
+ False,
189
+ "GPT4-o (Highly Recommended)"
190
+ ],
191
+ [
192
+ Image.open("./assets/girl_on_sun/girl_on_sun.png").convert("RGBA"),
193
+ "add a butterfly fairy.",
194
+ 648464818,
195
+ "girl_on_sun",
196
+ "girl_on_sun",
197
+ True,
198
+ False,
199
+ "GPT4-o (Highly Recommended)"
200
+ ],
201
+ [
202
+ Image.open("./assets/spider_man_rm/spider_man.png").convert("RGBA"),
203
+ "remove the christmas hat.",
204
+ 642087011,
205
+ "spider_man_rm",
206
+ "spider_man_rm",
207
+ False,
208
+ False,
209
+ "GPT4-o (Highly Recommended)"
210
+ ],
211
+ [
212
+ Image.open("./assets/anime_flower/anime_flower.png").convert("RGBA"),
213
+ "remove the flower.",
214
+ 642087011,
215
+ "anime_flower",
216
+ "anime_flower",
217
+ False,
218
+ False,
219
+ "GPT4-o (Highly Recommended)"
220
+ ],
221
+ [
222
+ Image.open("./assets/chenduling/chengduling.jpg").convert("RGBA"),
223
+ "replace the clothes to a delicated floral skirt.",
224
+ 648464818,
225
+ "chenduling",
226
+ "chenduling",
227
+ True,
228
+ False,
229
+ "GPT4-o (Highly Recommended)"
230
+ ],
231
+ [
232
+ Image.open("./assets/hedgehog_rp_bg/hedgehog.png").convert("RGBA"),
233
+ "make the hedgehog in Italy.",
234
+ 648464818,
235
+ "hedgehog_rp_bg",
236
+ "hedgehog_rp_bg",
237
+ True,
238
+ False,
239
+ "GPT4-o (Highly Recommended)"
240
+ ],
241
+
242
+ ]
243
+
244
+ INPUT_IMAGE_PATH = {
245
+ "frog": "./assets/frog/frog.jpeg",
246
+ "chinese_girl": "./assets/chinese_girl/chinese_girl.png",
247
+ "angel_christmas": "./assets/angel_christmas/angel_christmas.png",
248
+ "sunflower_girl": "./assets/sunflower_girl/sunflower_girl.png",
249
+ "girl_on_sun": "./assets/girl_on_sun/girl_on_sun.png",
250
+ "spider_man_rm": "./assets/spider_man_rm/spider_man.png",
251
+ "anime_flower": "./assets/anime_flower/anime_flower.png",
252
+ "chenduling": "./assets/chenduling/chengduling.jpg",
253
+ "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/hedgehog.png",
254
+ }
255
+ MASK_IMAGE_PATH = {
256
+ "frog": "./assets/frog/mask_f7b350de-6f2c-49e3-b535-995c486d78e7.png",
257
+ "chinese_girl": "./assets/chinese_girl/mask_54759648-0989-48e0-bc82-f20e28b5ec29.png",
258
+ "angel_christmas": "./assets/angel_christmas/mask_f15d9b45-c978-4e3d-9f5f-251e308560c3.png",
259
+ "sunflower_girl": "./assets/sunflower_girl/mask_99cc50b4-7dc4-4de5-8748-ec10772f0317.png",
260
+ "girl_on_sun": "./assets/girl_on_sun/mask_264eac8b-8b65-479c-9755-020a60880c37.png",
261
+ "spider_man_rm": "./assets/spider_man_rm/mask_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png",
262
+ "anime_flower": "./assets/anime_flower/mask_37553172-9b38-4727-bf2e-37d7e2b93461.png",
263
+ "chenduling": "./assets/chenduling/mask_68e3ff6f-da07-4b37-91df-13d6eed7b997.png",
264
+ "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/mask_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png",
265
+ }
266
+ MASKED_IMAGE_PATH = {
267
+ "frog": "./assets/frog/masked_image_f7b350de-6f2c-49e3-b535-995c486d78e7.png",
268
+ "chinese_girl": "./assets/chinese_girl/masked_image_54759648-0989-48e0-bc82-f20e28b5ec29.png",
269
+ "angel_christmas": "./assets/angel_christmas/masked_image_f15d9b45-c978-4e3d-9f5f-251e308560c3.png",
270
+ "sunflower_girl": "./assets/sunflower_girl/masked_image_99cc50b4-7dc4-4de5-8748-ec10772f0317.png",
271
+ "girl_on_sun": "./assets/girl_on_sun/masked_image_264eac8b-8b65-479c-9755-020a60880c37.png",
272
+ "spider_man_rm": "./assets/spider_man_rm/masked_image_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png",
273
+ "anime_flower": "./assets/anime_flower/masked_image_37553172-9b38-4727-bf2e-37d7e2b93461.png",
274
+ "chenduling": "./assets/chenduling/masked_image_68e3ff6f-da07-4b37-91df-13d6eed7b997.png",
275
+ "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/masked_image_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png",
276
+ }
277
+ OUTPUT_IMAGE_PATH = {
278
+ "frog": "./assets/frog/image_edit_f7b350de-6f2c-49e3-b535-995c486d78e7_1.png",
279
+ "chinese_girl": "./assets/chinese_girl/image_edit_54759648-0989-48e0-bc82-f20e28b5ec29_1.png",
280
+ "angel_christmas": "./assets/angel_christmas/image_edit_f15d9b45-c978-4e3d-9f5f-251e308560c3_0.png",
281
+ "sunflower_girl": "./assets/sunflower_girl/image_edit_99cc50b4-7dc4-4de5-8748-ec10772f0317_3.png",
282
+ "girl_on_sun": "./assets/girl_on_sun/image_edit_264eac8b-8b65-479c-9755-020a60880c37_0.png",
283
+ "spider_man_rm": "./assets/spider_man_rm/image_edit_a5d410e6-8e8d-432f-8144-defbc3e1eae9_0.png",
284
+ "anime_flower": "./assets/anime_flower/image_edit_37553172-9b38-4727-bf2e-37d7e2b93461_2.png",
285
+ "chenduling": "./assets/chenduling/image_edit_68e3ff6f-da07-4b37-91df-13d6eed7b997_0.png",
286
+ "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/image_edit_db7f8bf8-8349-46d3-b14e-43d67fbe25d3_3.png",
287
+ }
288
+
289
+ # os.environ['GRADIO_TEMP_DIR'] = 'gradio_temp_dir'
290
+ # os.makedirs('gradio_temp_dir', exist_ok=True)
291
+
292
+ VLM_MODEL_NAMES = list(vlms_template.keys())
293
+ DEFAULT_VLM_MODEL_NAME = "Qwen2-VL-7B-Instruct (Default)"
294
+ BASE_MODELS = list(base_models_template.keys())
295
+ DEFAULT_BASE_MODEL = "realisticVision (Default)"
296
+
297
+ ASPECT_RATIO_LABELS = list(aspect_ratios)
298
+ DEFAULT_ASPECT_RATIO = ASPECT_RATIO_LABELS[0]
299
+
300
+
301
+ ## init device
302
+ try:
303
+ if torch.cuda.is_available():
304
+ device = "cuda"
305
+ elif sys.platform == "darwin" and torch.backends.mps.is_available():
306
+ device = "mps"
307
+ else:
308
+ device = "cpu"
309
+ except:
310
+ device = "cpu"
311
+
312
+ # ## init torch dtype
313
+ # if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
314
+ # torch_dtype = torch.bfloat16
315
+ # else:
316
+ # torch_dtype = torch.float16
317
+
318
+ # if device == "mps":
319
+ # torch_dtype = torch.float16
320
+
321
+ torch_dtype = torch.float16
322
+
323
+
324
+
325
+ # download hf models
326
+ BrushEdit_path = "models/"
327
+ if not os.path.exists(BrushEdit_path):
328
+ BrushEdit_path = snapshot_download(
329
+ repo_id="TencentARC/BrushEdit",
330
+ local_dir=BrushEdit_path,
331
+ token=os.getenv("HF_TOKEN"),
332
+ )
333
+
334
+ ## init default VLM
335
+ vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[DEFAULT_VLM_MODEL_NAME]
336
+ if vlm_processor != "" and vlm_model != "":
337
+ vlm_model.to(device)
338
+ else:
339
+ raise gr.Error("Please Download default VLM model "+ DEFAULT_VLM_MODEL_NAME +" first.")
340
+
341
+
342
+ ## init base model
343
+ base_model_path = os.path.join(BrushEdit_path, "base_model/realisticVisionV60B1_v51VAE")
344
+ brushnet_path = os.path.join(BrushEdit_path, "brushnetX")
345
+ sam_path = os.path.join(BrushEdit_path, "sam/sam_vit_h_4b8939.pth")
346
+ groundingdino_path = os.path.join(BrushEdit_path, "grounding_dino/groundingdino_swint_ogc.pth")
347
+
348
+
349
+ # input brushnetX ckpt path
350
+ brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch_dtype)
351
+ pipe = StableDiffusionBrushNetPipeline.from_pretrained(
352
+ base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
353
+ )
354
+ # speed up diffusion process with faster scheduler and memory optimization
355
+ pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
356
+ # remove following line if xformers is not installed or when using Torch 2.0.
357
+ # pipe.enable_xformers_memory_efficient_attention()
358
+ pipe.enable_model_cpu_offload()
359
+
360
+
361
+ ## init SAM
362
+ sam = build_sam(checkpoint=sam_path)
363
+ sam.to(device=device)
364
+ sam_predictor = SamPredictor(sam)
365
+ sam_automask_generator = SamAutomaticMaskGenerator(sam)
366
+
367
+ ## init groundingdino_model
368
+ config_file = 'app/utils/GroundingDINO_SwinT_OGC.py'
369
+ groundingdino_model = load_grounding_dino_model(config_file, groundingdino_path, device=device)
370
+
371
+ ## Ordinary function
372
+ def crop_and_resize(image: Image.Image,
373
+ target_width: int,
374
+ target_height: int) -> Image.Image:
375
+ """
376
+ Crops and resizes an image while preserving the aspect ratio.
377
+
378
+ Args:
379
+ image (Image.Image): Input PIL image to be cropped and resized.
380
+ target_width (int): Target width of the output image.
381
+ target_height (int): Target height of the output image.
382
+
383
+ Returns:
384
+ Image.Image: Cropped and resized image.
385
+ """
386
+ # Original dimensions
387
+ original_width, original_height = image.size
388
+ original_aspect = original_width / original_height
389
+ target_aspect = target_width / target_height
390
+
391
+ # Calculate crop box to maintain aspect ratio
392
+ if original_aspect > target_aspect:
393
+ # Crop horizontally
394
+ new_width = int(original_height * target_aspect)
395
+ new_height = original_height
396
+ left = (original_width - new_width) / 2
397
+ top = 0
398
+ right = left + new_width
399
+ bottom = original_height
400
+ else:
401
+ # Crop vertically
402
+ new_width = original_width
403
+ new_height = int(original_width / target_aspect)
404
+ left = 0
405
+ top = (original_height - new_height) / 2
406
+ right = original_width
407
+ bottom = top + new_height
408
+
409
+ # Crop and resize
410
+ cropped_image = image.crop((left, top, right, bottom))
411
+ resized_image = cropped_image.resize((target_width, target_height), Image.NEAREST)
412
+ return resized_image
413
+
414
+
415
+ ## Ordinary function
416
+ def resize(image: Image.Image,
417
+ target_width: int,
418
+ target_height: int) -> Image.Image:
419
+ """
420
+ Crops and resizes an image while preserving the aspect ratio.
421
+
422
+ Args:
423
+ image (Image.Image): Input PIL image to be cropped and resized.
424
+ target_width (int): Target width of the output image.
425
+ target_height (int): Target height of the output image.
426
+
427
+ Returns:
428
+ Image.Image: Cropped and resized image.
429
+ """
430
+ # Original dimensions
431
+ resized_image = image.resize((target_width, target_height), Image.NEAREST)
432
+ return resized_image
433
+
434
+
435
+ def move_mask_func(mask, direction, units):
436
+ binary_mask = mask.squeeze()>0
437
+ rows, cols = binary_mask.shape
438
+ moved_mask = np.zeros_like(binary_mask, dtype=bool)
439
+
440
+ if direction == 'down':
441
+ # move down
442
+ moved_mask[max(0, units):, :] = binary_mask[:rows - units, :]
443
+
444
+ elif direction == 'up':
445
+ # move up
446
+ moved_mask[:rows - units, :] = binary_mask[units:, :]
447
+
448
+ elif direction == 'right':
449
+ # move right
450
+ moved_mask[:, max(0, units):] = binary_mask[:, :cols - units]
451
+
452
+ elif direction == 'left':
453
+ # move left
454
+ moved_mask[:, :cols - units] = binary_mask[:, units:]
455
+
456
+ return moved_mask
457
+
458
+
459
+ def random_mask_func(mask, dilation_type='square', dilation_size=20):
460
+ # Randomly select the size of dilation
461
+ binary_mask = mask.squeeze()>0
462
+
463
+ if dilation_type == 'square_dilation':
464
+ structure = np.ones((dilation_size, dilation_size), dtype=bool)
465
+ dilated_mask = binary_dilation(binary_mask, structure=structure)
466
+ elif dilation_type == 'square_erosion':
467
+ structure = np.ones((dilation_size, dilation_size), dtype=bool)
468
+ dilated_mask = binary_erosion(binary_mask, structure=structure)
469
+ elif dilation_type == 'bounding_box':
470
+ # find the most left top and left bottom point
471
+ rows, cols = np.where(binary_mask)
472
+ if len(rows) == 0 or len(cols) == 0:
473
+ return mask # return original mask if no valid points
474
+
475
+ min_row = np.min(rows)
476
+ max_row = np.max(rows)
477
+ min_col = np.min(cols)
478
+ max_col = np.max(cols)
479
+
480
+ # create a bounding box
481
+ dilated_mask = np.zeros_like(binary_mask, dtype=bool)
482
+ dilated_mask[min_row:max_row + 1, min_col:max_col + 1] = True
483
+
484
+ elif dilation_type == 'bounding_ellipse':
485
+ # find the most left top and left bottom point
486
+ rows, cols = np.where(binary_mask)
487
+ if len(rows) == 0 or len(cols) == 0:
488
+ return mask # return original mask if no valid points
489
+
490
+ min_row = np.min(rows)
491
+ max_row = np.max(rows)
492
+ min_col = np.min(cols)
493
+ max_col = np.max(cols)
494
+
495
+ # calculate the center and axis length of the ellipse
496
+ center = ((min_col + max_col) // 2, (min_row + max_row) // 2)
497
+ a = (max_col - min_col) // 2 # half long axis
498
+ b = (max_row - min_row) // 2 # half short axis
499
+
500
+ # create a bounding ellipse
501
+ y, x = np.ogrid[:mask.shape[0], :mask.shape[1]]
502
+ ellipse_mask = ((x - center[0])**2 / a**2 + (y - center[1])**2 / b**2) <= 1
503
+ dilated_mask = np.zeros_like(binary_mask, dtype=bool)
504
+ dilated_mask[ellipse_mask] = True
505
+ else:
506
+ raise ValueError("dilation_type must be 'square' or 'ellipse'")
507
+
508
+ # use binary dilation
509
+ dilated_mask = np.uint8(dilated_mask[:,:,np.newaxis]) * 255
510
+ return dilated_mask
511
+
512
+
513
+ ## Gradio component function
514
+ def update_vlm_model(vlm_name):
515
+ global vlm_model, vlm_processor
516
+ if vlm_model is not None:
517
+ del vlm_model
518
+ torch.cuda.empty_cache()
519
+
520
+ vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[vlm_name]
521
+
522
+ ## we recommend using preload models, otherwise it will take a long time to download the model. you can edit the code via vlm_template.py
523
+ if vlm_type == "llava-next":
524
+ if vlm_processor != "" and vlm_model != "":
525
+ vlm_model.to(device)
526
+ return vlm_model_dropdown
527
+ else:
528
+ if os.path.exists(vlm_local_path):
529
+ vlm_processor = LlavaNextProcessor.from_pretrained(vlm_local_path)
530
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto")
531
+ else:
532
+ if vlm_name == "llava-v1.6-mistral-7b-hf (Preload)":
533
+ vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
534
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", torch_dtype="auto", device_map="auto")
535
+ elif vlm_name == "llama3-llava-next-8b-hf (Preload)":
536
+ vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llama3-llava-next-8b-hf")
537
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llama3-llava-next-8b-hf", torch_dtype="auto", device_map="auto")
538
+ elif vlm_name == "llava-v1.6-vicuna-13b-hf (Preload)":
539
+ vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf")
540
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf", torch_dtype="auto", device_map="auto")
541
+ elif vlm_name == "llava-v1.6-34b-hf (Preload)":
542
+ vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-34b-hf")
543
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-34b-hf", torch_dtype="auto", device_map="auto")
544
+ elif vlm_name == "llava-next-72b-hf (Preload)":
545
+ vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-next-72b-hf")
546
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-next-72b-hf", torch_dtype="auto", device_map="auto")
547
+ elif vlm_type == "qwen2-vl":
548
+ if vlm_processor != "" and vlm_model != "":
549
+ vlm_model.to(device)
550
+ return vlm_model_dropdown
551
+ else:
552
+ if os.path.exists(vlm_local_path):
553
+ vlm_processor = Qwen2VLProcessor.from_pretrained(vlm_local_path)
554
+ vlm_model = Qwen2VLForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto")
555
+ else:
556
+ if vlm_name == "qwen2-vl-2b-instruct (Preload)":
557
+ vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
558
+ vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", torch_dtype="auto", device_map="auto")
559
+ elif vlm_name == "qwen2-vl-7b-instruct (Preload)":
560
+ vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
561
+ vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto")
562
+ elif vlm_name == "qwen2-vl-72b-instruct (Preload)":
563
+ vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-72B-Instruct")
564
+ vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-72B-Instruct", torch_dtype="auto", device_map="auto")
565
+ elif vlm_type == "openai":
566
+ pass
567
+ return "success"
568
+
569
+
570
+ def update_base_model(base_model_name):
571
+ global pipe
572
+ ## we recommend using preload models, otherwise it will take a long time to download the model. you can edit the code via base_model_template.py
573
+ if pipe is not None:
574
+ del pipe
575
+ torch.cuda.empty_cache()
576
+ base_model_path, pipe = base_models_template[base_model_name]
577
+ if pipe != "":
578
+ pipe.to(device)
579
+ else:
580
+ if os.path.exists(base_model_path):
581
+ pipe = StableDiffusionBrushNetPipeline.from_pretrained(
582
+ base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
583
+ )
584
+ # pipe.enable_xformers_memory_efficient_attention()
585
+ pipe.enable_model_cpu_offload()
586
+ else:
587
+ raise gr.Error(f"The base model {base_model_name} does not exist")
588
+ return "success"
589
+
590
+
591
+ def submit_GPT4o_KEY(GPT4o_KEY):
592
+ global vlm_model, vlm_processor
593
+ if vlm_model is not None:
594
+ del vlm_model
595
+ torch.cuda.empty_cache()
596
+ try:
597
+ vlm_model = OpenAI(api_key=GPT4o_KEY, base_url="https://api.deepseek.com")
598
+ # vlm_model = OpenAI(api_key="sk-d145b963a92649a88843caeb741e8bbc", base_url="https://api.deepseek.com")
599
+ vlm_processor = ""
600
+ response = vlm_model.chat.completions.create(
601
+ model="deepseek-chat",
602
+ messages=[
603
+ {"role": "system", "content": "You are a helpful assistant."},
604
+ {"role": "user", "content": "Hello."}
605
+ ]
606
+ )
607
+ response_str = response.choices[0].message.content
608
+
609
+ return "Success, " + response_str, "GPT4-o (Highly Recommended)"
610
+ except Exception as e:
611
+ return "Invalid GPT4o API Key", "GPT4-o (Highly Recommended)"
612
+
613
+
614
+
615
+ def process(input_image,
616
+ original_image,
617
+ original_mask,
618
+ prompt,
619
+ negative_prompt,
620
+ control_strength,
621
+ seed,
622
+ randomize_seed,
623
+ guidance_scale,
624
+ num_inference_steps,
625
+ num_samples,
626
+ blending,
627
+ category,
628
+ target_prompt,
629
+ resize_default,
630
+ aspect_ratio_name,
631
+ invert_mask_state):
632
+ if original_image is None:
633
+ if input_image is None:
634
+ raise gr.Error('Please upload the input image')
635
+ else:
636
+ image_pil = input_image["background"].convert("RGB")
637
+ original_image = np.array(image_pil)
638
+ if prompt is None or prompt == "":
639
+ if target_prompt is None or target_prompt == "":
640
+ raise gr.Error("Please input your instructions, e.g., remove the xxx")
641
+
642
+ alpha_mask = input_image["layers"][0].split()[3]
643
+ input_mask = np.asarray(alpha_mask)
644
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
645
+ if output_w == "" or output_h == "":
646
+ output_h, output_w = original_image.shape[:2]
647
+
648
+ if resize_default:
649
+ short_side = min(output_w, output_h)
650
+ scale_ratio = 640 / short_side
651
+ output_w = int(output_w * scale_ratio)
652
+ output_h = int(output_h * scale_ratio)
653
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
654
+ original_image = np.array(original_image)
655
+ if input_mask is not None:
656
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
657
+ input_mask = np.array(input_mask)
658
+ if original_mask is not None:
659
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
660
+ original_mask = np.array(original_mask)
661
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
662
+ else:
663
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
664
+ pass
665
+ else:
666
+ if resize_default:
667
+ short_side = min(output_w, output_h)
668
+ scale_ratio = 640 / short_side
669
+ output_w = int(output_w * scale_ratio)
670
+ output_h = int(output_h * scale_ratio)
671
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
672
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
673
+ original_image = np.array(original_image)
674
+ if input_mask is not None:
675
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
676
+ input_mask = np.array(input_mask)
677
+ if original_mask is not None:
678
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
679
+ original_mask = np.array(original_mask)
680
+
681
+ if invert_mask_state:
682
+ original_mask = original_mask
683
+ else:
684
+ if input_mask.max() == 0:
685
+ original_mask = original_mask
686
+ else:
687
+ original_mask = input_mask
688
+
689
+
690
+ ## inpainting directly if target_prompt is not None
691
+ if category is not None:
692
+ pass
693
+ elif target_prompt is not None and len(target_prompt) >= 1 and original_mask is not None:
694
+ pass
695
+ else:
696
+ try:
697
+ category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device)
698
+ except Exception as e:
699
+ raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
700
+
701
+
702
+ if original_mask is not None:
703
+ original_mask = np.clip(original_mask, 0, 255).astype(np.uint8)
704
+ else:
705
+ try:
706
+ object_wait_for_edit = vlm_response_object_wait_for_edit(
707
+ vlm_processor,
708
+ vlm_model,
709
+ original_image,
710
+ category,
711
+ prompt,
712
+ device)
713
+
714
+ original_mask = vlm_response_mask(vlm_processor,
715
+ vlm_model,
716
+ category,
717
+ original_image,
718
+ prompt,
719
+ object_wait_for_edit,
720
+ sam,
721
+ sam_predictor,
722
+ sam_automask_generator,
723
+ groundingdino_model,
724
+ device).astype(np.uint8)
725
+ except Exception as e:
726
+ raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
727
+
728
+ if original_mask.ndim == 2:
729
+ original_mask = original_mask[:,:,None]
730
+
731
+
732
+ if target_prompt is not None and len(target_prompt) >= 1:
733
+ prompt_after_apply_instruction = target_prompt
734
+
735
+ else:
736
+ try:
737
+ prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(
738
+ vlm_processor,
739
+ vlm_model,
740
+ original_image,
741
+ prompt,
742
+ device)
743
+ except Exception as e:
744
+ raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
745
+
746
+ generator = torch.Generator(device).manual_seed(random.randint(0, 2147483647) if randomize_seed else seed)
747
+
748
+
749
+ with torch.autocast(device):
750
+ image, mask_image, mask_np, init_image_np = BrushEdit_Pipeline(pipe,
751
+ prompt_after_apply_instruction,
752
+ original_mask,
753
+ original_image,
754
+ generator,
755
+ num_inference_steps,
756
+ guidance_scale,
757
+ control_strength,
758
+ negative_prompt,
759
+ num_samples,
760
+ blending)
761
+ original_image = np.array(init_image_np)
762
+ masked_image = original_image * (1 - (mask_np>0))
763
+ masked_image = masked_image.astype(np.uint8)
764
+ masked_image = Image.fromarray(masked_image)
765
+ # Save the images (optional)
766
+ # import uuid
767
+ # uuid = str(uuid.uuid4())
768
+ # image[0].save(f"outputs/image_edit_{uuid}_0.png")
769
+ # image[1].save(f"outputs/image_edit_{uuid}_1.png")
770
+ # image[2].save(f"outputs/image_edit_{uuid}_2.png")
771
+ # image[3].save(f"outputs/image_edit_{uuid}_3.png")
772
+ # mask_image.save(f"outputs/mask_{uuid}.png")
773
+ # masked_image.save(f"outputs/masked_image_{uuid}.png")
774
+ gr.Info(f"Target Prompt: {prompt_after_apply_instruction}", duration=20)
775
+ return image, [mask_image], [masked_image], prompt, '', False
776
+
777
+
778
+ def generate_target_prompt(input_image,
779
+ original_image,
780
+ prompt):
781
+ # load example image
782
+ if isinstance(original_image, str):
783
+ original_image = input_image
784
+
785
+ prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(
786
+ vlm_processor,
787
+ vlm_model,
788
+ original_image,
789
+ prompt,
790
+ device)
791
+ return prompt_after_apply_instruction
792
+
793
+
794
+ def process_mask(input_image,
795
+ original_image,
796
+ prompt,
797
+ resize_default,
798
+ aspect_ratio_name):
799
+ if original_image is None:
800
+ raise gr.Error('Please upload the input image')
801
+ if prompt is None:
802
+ raise gr.Error("Please input your instructions, e.g., remove the xxx")
803
+
804
+ ## load mask
805
+ alpha_mask = input_image["layers"][0].split()[3]
806
+ input_mask = np.array(alpha_mask)
807
+
808
+ # load example image
809
+ if isinstance(original_image, str):
810
+ original_image = input_image["background"]
811
+
812
+ if input_mask.max() == 0:
813
+ category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device)
814
+
815
+ object_wait_for_edit = vlm_response_object_wait_for_edit(vlm_processor,
816
+ vlm_model,
817
+ original_image,
818
+ category,
819
+ prompt,
820
+ device)
821
+ # original mask: h,w,1 [0, 255]
822
+ original_mask = vlm_response_mask(
823
+ vlm_processor,
824
+ vlm_model,
825
+ category,
826
+ original_image,
827
+ prompt,
828
+ object_wait_for_edit,
829
+ sam,
830
+ sam_predictor,
831
+ sam_automask_generator,
832
+ groundingdino_model,
833
+ device).astype(np.uint8)
834
+ else:
835
+ original_mask = input_mask.astype(np.uint8)
836
+ category = None
837
+
838
+ ## resize mask if needed
839
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
840
+ if output_w == "" or output_h == "":
841
+ output_h, output_w = original_image.shape[:2]
842
+ if resize_default:
843
+ short_side = min(output_w, output_h)
844
+ scale_ratio = 640 / short_side
845
+ output_w = int(output_w * scale_ratio)
846
+ output_h = int(output_h * scale_ratio)
847
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
848
+ original_image = np.array(original_image)
849
+ if input_mask is not None:
850
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
851
+ input_mask = np.array(input_mask)
852
+ if original_mask is not None:
853
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
854
+ original_mask = np.array(original_mask)
855
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
856
+ else:
857
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
858
+ pass
859
+ else:
860
+ if resize_default:
861
+ short_side = min(output_w, output_h)
862
+ scale_ratio = 640 / short_side
863
+ output_w = int(output_w * scale_ratio)
864
+ output_h = int(output_h * scale_ratio)
865
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
866
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
867
+ original_image = np.array(original_image)
868
+ if input_mask is not None:
869
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
870
+ input_mask = np.array(input_mask)
871
+ if original_mask is not None:
872
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
873
+ original_mask = np.array(original_mask)
874
+
875
+
876
+ if original_mask.ndim == 2:
877
+ original_mask = original_mask[:,:,None]
878
+
879
+ mask_image = Image.fromarray(original_mask.squeeze().astype(np.uint8)).convert("RGB")
880
+
881
+ masked_image = original_image * (1 - (original_mask>0))
882
+ masked_image = masked_image.astype(np.uint8)
883
+ masked_image = Image.fromarray(masked_image)
884
+
885
+ return [masked_image], [mask_image], original_mask.astype(np.uint8), category
886
+
887
+
888
+ def process_random_mask(input_image,
889
+ original_image,
890
+ original_mask,
891
+ resize_default,
892
+ aspect_ratio_name,
893
+ ):
894
+
895
+ alpha_mask = input_image["layers"][0].split()[3]
896
+ input_mask = np.asarray(alpha_mask)
897
+
898
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
899
+ if output_w == "" or output_h == "":
900
+ output_h, output_w = original_image.shape[:2]
901
+ if resize_default:
902
+ short_side = min(output_w, output_h)
903
+ scale_ratio = 640 / short_side
904
+ output_w = int(output_w * scale_ratio)
905
+ output_h = int(output_h * scale_ratio)
906
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
907
+ original_image = np.array(original_image)
908
+ if input_mask is not None:
909
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
910
+ input_mask = np.array(input_mask)
911
+ if original_mask is not None:
912
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
913
+ original_mask = np.array(original_mask)
914
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
915
+ else:
916
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
917
+ pass
918
+ else:
919
+ if resize_default:
920
+ short_side = min(output_w, output_h)
921
+ scale_ratio = 640 / short_side
922
+ output_w = int(output_w * scale_ratio)
923
+ output_h = int(output_h * scale_ratio)
924
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
925
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
926
+ original_image = np.array(original_image)
927
+ if input_mask is not None:
928
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
929
+ input_mask = np.array(input_mask)
930
+ if original_mask is not None:
931
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
932
+ original_mask = np.array(original_mask)
933
+
934
+
935
+ if input_mask.max() == 0:
936
+ original_mask = original_mask
937
+ else:
938
+ original_mask = input_mask
939
+
940
+ if original_mask is None:
941
+ raise gr.Error('Please generate mask first')
942
+
943
+ if original_mask.ndim == 2:
944
+ original_mask = original_mask[:,:,None]
945
+
946
+ dilation_type = np.random.choice(['bounding_box', 'bounding_ellipse'])
947
+ random_mask = random_mask_func(original_mask, dilation_type).squeeze()
948
+
949
+ mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
950
+
951
+ masked_image = original_image * (1 - (random_mask[:,:,None]>0))
952
+ masked_image = masked_image.astype(original_image.dtype)
953
+ masked_image = Image.fromarray(masked_image)
954
+
955
+
956
+ return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
957
+
958
+
959
+ def process_dilation_mask(input_image,
960
+ original_image,
961
+ original_mask,
962
+ resize_default,
963
+ aspect_ratio_name,
964
+ dilation_size=20):
965
+
966
+ alpha_mask = input_image["layers"][0].split()[3]
967
+ input_mask = np.asarray(alpha_mask)
968
+
969
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
970
+ if output_w == "" or output_h == "":
971
+ output_h, output_w = original_image.shape[:2]
972
+ if resize_default:
973
+ short_side = min(output_w, output_h)
974
+ scale_ratio = 640 / short_side
975
+ output_w = int(output_w * scale_ratio)
976
+ output_h = int(output_h * scale_ratio)
977
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
978
+ original_image = np.array(original_image)
979
+ if input_mask is not None:
980
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
981
+ input_mask = np.array(input_mask)
982
+ if original_mask is not None:
983
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
984
+ original_mask = np.array(original_mask)
985
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
986
+ else:
987
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
988
+ pass
989
+ else:
990
+ if resize_default:
991
+ short_side = min(output_w, output_h)
992
+ scale_ratio = 640 / short_side
993
+ output_w = int(output_w * scale_ratio)
994
+ output_h = int(output_h * scale_ratio)
995
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
996
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
997
+ original_image = np.array(original_image)
998
+ if input_mask is not None:
999
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1000
+ input_mask = np.array(input_mask)
1001
+ if original_mask is not None:
1002
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1003
+ original_mask = np.array(original_mask)
1004
+
1005
+ if input_mask.max() == 0:
1006
+ original_mask = original_mask
1007
+ else:
1008
+ original_mask = input_mask
1009
+
1010
+ if original_mask is None:
1011
+ raise gr.Error('Please generate mask first')
1012
+
1013
+ if original_mask.ndim == 2:
1014
+ original_mask = original_mask[:,:,None]
1015
+
1016
+ dilation_type = np.random.choice(['square_dilation'])
1017
+ random_mask = random_mask_func(original_mask, dilation_type, dilation_size).squeeze()
1018
+
1019
+ mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
1020
+
1021
+ masked_image = original_image * (1 - (random_mask[:,:,None]>0))
1022
+ masked_image = masked_image.astype(original_image.dtype)
1023
+ masked_image = Image.fromarray(masked_image)
1024
+
1025
+ return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
1026
+
1027
+
1028
+ def process_erosion_mask(input_image,
1029
+ original_image,
1030
+ original_mask,
1031
+ resize_default,
1032
+ aspect_ratio_name,
1033
+ dilation_size=20):
1034
+ alpha_mask = input_image["layers"][0].split()[3]
1035
+ input_mask = np.asarray(alpha_mask)
1036
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
1037
+ if output_w == "" or output_h == "":
1038
+ output_h, output_w = original_image.shape[:2]
1039
+ if resize_default:
1040
+ short_side = min(output_w, output_h)
1041
+ scale_ratio = 640 / short_side
1042
+ output_w = int(output_w * scale_ratio)
1043
+ output_h = int(output_h * scale_ratio)
1044
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1045
+ original_image = np.array(original_image)
1046
+ if input_mask is not None:
1047
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1048
+ input_mask = np.array(input_mask)
1049
+ if original_mask is not None:
1050
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1051
+ original_mask = np.array(original_mask)
1052
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1053
+ else:
1054
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1055
+ pass
1056
+ else:
1057
+ if resize_default:
1058
+ short_side = min(output_w, output_h)
1059
+ scale_ratio = 640 / short_side
1060
+ output_w = int(output_w * scale_ratio)
1061
+ output_h = int(output_h * scale_ratio)
1062
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1063
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1064
+ original_image = np.array(original_image)
1065
+ if input_mask is not None:
1066
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1067
+ input_mask = np.array(input_mask)
1068
+ if original_mask is not None:
1069
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1070
+ original_mask = np.array(original_mask)
1071
+
1072
+ if input_mask.max() == 0:
1073
+ original_mask = original_mask
1074
+ else:
1075
+ original_mask = input_mask
1076
+
1077
+ if original_mask is None:
1078
+ raise gr.Error('Please generate mask first')
1079
+
1080
+ if original_mask.ndim == 2:
1081
+ original_mask = original_mask[:,:,None]
1082
+
1083
+ dilation_type = np.random.choice(['square_erosion'])
1084
+ random_mask = random_mask_func(original_mask, dilation_type, dilation_size).squeeze()
1085
+
1086
+ mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
1087
+
1088
+ masked_image = original_image * (1 - (random_mask[:,:,None]>0))
1089
+ masked_image = masked_image.astype(original_image.dtype)
1090
+ masked_image = Image.fromarray(masked_image)
1091
+
1092
+
1093
+ return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
1094
+
1095
+
1096
+ def move_mask_left(input_image,
1097
+ original_image,
1098
+ original_mask,
1099
+ moving_pixels,
1100
+ resize_default,
1101
+ aspect_ratio_name):
1102
+
1103
+ alpha_mask = input_image["layers"][0].split()[3]
1104
+ input_mask = np.asarray(alpha_mask)
1105
+
1106
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
1107
+ if output_w == "" or output_h == "":
1108
+ output_h, output_w = original_image.shape[:2]
1109
+ if resize_default:
1110
+ short_side = min(output_w, output_h)
1111
+ scale_ratio = 640 / short_side
1112
+ output_w = int(output_w * scale_ratio)
1113
+ output_h = int(output_h * scale_ratio)
1114
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1115
+ original_image = np.array(original_image)
1116
+ if input_mask is not None:
1117
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1118
+ input_mask = np.array(input_mask)
1119
+ if original_mask is not None:
1120
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1121
+ original_mask = np.array(original_mask)
1122
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1123
+ else:
1124
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1125
+ pass
1126
+ else:
1127
+ if resize_default:
1128
+ short_side = min(output_w, output_h)
1129
+ scale_ratio = 640 / short_side
1130
+ output_w = int(output_w * scale_ratio)
1131
+ output_h = int(output_h * scale_ratio)
1132
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1133
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1134
+ original_image = np.array(original_image)
1135
+ if input_mask is not None:
1136
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1137
+ input_mask = np.array(input_mask)
1138
+ if original_mask is not None:
1139
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1140
+ original_mask = np.array(original_mask)
1141
+
1142
+ if input_mask.max() == 0:
1143
+ original_mask = original_mask
1144
+ else:
1145
+ original_mask = input_mask
1146
+
1147
+ if original_mask is None:
1148
+ raise gr.Error('Please generate mask first')
1149
+
1150
+ if original_mask.ndim == 2:
1151
+ original_mask = original_mask[:,:,None]
1152
+
1153
+ moved_mask = move_mask_func(original_mask, 'left', int(moving_pixels)).squeeze()
1154
+ mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
1155
+
1156
+ masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
1157
+ masked_image = masked_image.astype(original_image.dtype)
1158
+ masked_image = Image.fromarray(masked_image)
1159
+
1160
+ if moved_mask.max() <= 1:
1161
+ moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
1162
+ original_mask = moved_mask
1163
+ return [masked_image], [mask_image], original_mask.astype(np.uint8)
1164
+
1165
+
1166
+ def move_mask_right(input_image,
1167
+ original_image,
1168
+ original_mask,
1169
+ moving_pixels,
1170
+ resize_default,
1171
+ aspect_ratio_name):
1172
+ alpha_mask = input_image["layers"][0].split()[3]
1173
+ input_mask = np.asarray(alpha_mask)
1174
+
1175
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
1176
+ if output_w == "" or output_h == "":
1177
+ output_h, output_w = original_image.shape[:2]
1178
+ if resize_default:
1179
+ short_side = min(output_w, output_h)
1180
+ scale_ratio = 640 / short_side
1181
+ output_w = int(output_w * scale_ratio)
1182
+ output_h = int(output_h * scale_ratio)
1183
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1184
+ original_image = np.array(original_image)
1185
+ if input_mask is not None:
1186
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1187
+ input_mask = np.array(input_mask)
1188
+ if original_mask is not None:
1189
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1190
+ original_mask = np.array(original_mask)
1191
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1192
+ else:
1193
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1194
+ pass
1195
+ else:
1196
+ if resize_default:
1197
+ short_side = min(output_w, output_h)
1198
+ scale_ratio = 640 / short_side
1199
+ output_w = int(output_w * scale_ratio)
1200
+ output_h = int(output_h * scale_ratio)
1201
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1202
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1203
+ original_image = np.array(original_image)
1204
+ if input_mask is not None:
1205
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1206
+ input_mask = np.array(input_mask)
1207
+ if original_mask is not None:
1208
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1209
+ original_mask = np.array(original_mask)
1210
+
1211
+ if input_mask.max() == 0:
1212
+ original_mask = original_mask
1213
+ else:
1214
+ original_mask = input_mask
1215
+
1216
+ if original_mask is None:
1217
+ raise gr.Error('Please generate mask first')
1218
+
1219
+ if original_mask.ndim == 2:
1220
+ original_mask = original_mask[:,:,None]
1221
+
1222
+ moved_mask = move_mask_func(original_mask, 'right', int(moving_pixels)).squeeze()
1223
+
1224
+ mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
1225
+
1226
+ masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
1227
+ masked_image = masked_image.astype(original_image.dtype)
1228
+ masked_image = Image.fromarray(masked_image)
1229
+
1230
+
1231
+ if moved_mask.max() <= 1:
1232
+ moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
1233
+ original_mask = moved_mask
1234
+
1235
+ return [masked_image], [mask_image], original_mask.astype(np.uint8)
1236
+
1237
+
1238
+ def move_mask_up(input_image,
1239
+ original_image,
1240
+ original_mask,
1241
+ moving_pixels,
1242
+ resize_default,
1243
+ aspect_ratio_name):
1244
+ alpha_mask = input_image["layers"][0].split()[3]
1245
+ input_mask = np.asarray(alpha_mask)
1246
+
1247
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
1248
+ if output_w == "" or output_h == "":
1249
+ output_h, output_w = original_image.shape[:2]
1250
+ if resize_default:
1251
+ short_side = min(output_w, output_h)
1252
+ scale_ratio = 640 / short_side
1253
+ output_w = int(output_w * scale_ratio)
1254
+ output_h = int(output_h * scale_ratio)
1255
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1256
+ original_image = np.array(original_image)
1257
+ if input_mask is not None:
1258
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1259
+ input_mask = np.array(input_mask)
1260
+ if original_mask is not None:
1261
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1262
+ original_mask = np.array(original_mask)
1263
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1264
+ else:
1265
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1266
+ pass
1267
+ else:
1268
+ if resize_default:
1269
+ short_side = min(output_w, output_h)
1270
+ scale_ratio = 640 / short_side
1271
+ output_w = int(output_w * scale_ratio)
1272
+ output_h = int(output_h * scale_ratio)
1273
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1274
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1275
+ original_image = np.array(original_image)
1276
+ if input_mask is not None:
1277
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1278
+ input_mask = np.array(input_mask)
1279
+ if original_mask is not None:
1280
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1281
+ original_mask = np.array(original_mask)
1282
+
1283
+ if input_mask.max() == 0:
1284
+ original_mask = original_mask
1285
+ else:
1286
+ original_mask = input_mask
1287
+
1288
+ if original_mask is None:
1289
+ raise gr.Error('Please generate mask first')
1290
+
1291
+ if original_mask.ndim == 2:
1292
+ original_mask = original_mask[:,:,None]
1293
+
1294
+ moved_mask = move_mask_func(original_mask, 'up', int(moving_pixels)).squeeze()
1295
+ mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
1296
+
1297
+ masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
1298
+ masked_image = masked_image.astype(original_image.dtype)
1299
+ masked_image = Image.fromarray(masked_image)
1300
+
1301
+ if moved_mask.max() <= 1:
1302
+ moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
1303
+ original_mask = moved_mask
1304
+
1305
+ return [masked_image], [mask_image], original_mask.astype(np.uint8)
1306
+
1307
+
1308
+ def move_mask_down(input_image,
1309
+ original_image,
1310
+ original_mask,
1311
+ moving_pixels,
1312
+ resize_default,
1313
+ aspect_ratio_name):
1314
+ alpha_mask = input_image["layers"][0].split()[3]
1315
+ input_mask = np.asarray(alpha_mask)
1316
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
1317
+ if output_w == "" or output_h == "":
1318
+ output_h, output_w = original_image.shape[:2]
1319
+ if resize_default:
1320
+ short_side = min(output_w, output_h)
1321
+ scale_ratio = 640 / short_side
1322
+ output_w = int(output_w * scale_ratio)
1323
+ output_h = int(output_h * scale_ratio)
1324
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1325
+ original_image = np.array(original_image)
1326
+ if input_mask is not None:
1327
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1328
+ input_mask = np.array(input_mask)
1329
+ if original_mask is not None:
1330
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1331
+ original_mask = np.array(original_mask)
1332
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1333
+ else:
1334
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1335
+ pass
1336
+ else:
1337
+ if resize_default:
1338
+ short_side = min(output_w, output_h)
1339
+ scale_ratio = 640 / short_side
1340
+ output_w = int(output_w * scale_ratio)
1341
+ output_h = int(output_h * scale_ratio)
1342
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1343
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1344
+ original_image = np.array(original_image)
1345
+ if input_mask is not None:
1346
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1347
+ input_mask = np.array(input_mask)
1348
+ if original_mask is not None:
1349
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1350
+ original_mask = np.array(original_mask)
1351
+
1352
+ if input_mask.max() == 0:
1353
+ original_mask = original_mask
1354
+ else:
1355
+ original_mask = input_mask
1356
+
1357
+ if original_mask is None:
1358
+ raise gr.Error('Please generate mask first')
1359
+
1360
+ if original_mask.ndim == 2:
1361
+ original_mask = original_mask[:,:,None]
1362
+
1363
+ moved_mask = move_mask_func(original_mask, 'down', int(moving_pixels)).squeeze()
1364
+ mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
1365
+
1366
+ masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
1367
+ masked_image = masked_image.astype(original_image.dtype)
1368
+ masked_image = Image.fromarray(masked_image)
1369
+
1370
+ if moved_mask.max() <= 1:
1371
+ moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
1372
+ original_mask = moved_mask
1373
+
1374
+ return [masked_image], [mask_image], original_mask.astype(np.uint8)
1375
+
1376
+
1377
+ def invert_mask(input_image,
1378
+ original_image,
1379
+ original_mask,
1380
+ ):
1381
+ alpha_mask = input_image["layers"][0].split()[3]
1382
+ input_mask = np.asarray(alpha_mask)
1383
+ if input_mask.max() == 0:
1384
+ original_mask = 1 - (original_mask>0).astype(np.uint8)
1385
+ else:
1386
+ original_mask = 1 - (input_mask>0).astype(np.uint8)
1387
+
1388
+ if original_mask is None:
1389
+ raise gr.Error('Please generate mask first')
1390
+
1391
+ original_mask = original_mask.squeeze()
1392
+ mask_image = Image.fromarray(original_mask*255).convert("RGB")
1393
+
1394
+ if original_mask.ndim == 2:
1395
+ original_mask = original_mask[:,:,None]
1396
+
1397
+ if original_mask.max() <= 1:
1398
+ original_mask = (original_mask * 255).astype(np.uint8)
1399
+
1400
+ masked_image = original_image * (1 - (original_mask>0))
1401
+ masked_image = masked_image.astype(original_image.dtype)
1402
+ masked_image = Image.fromarray(masked_image)
1403
+
1404
+ return [masked_image], [mask_image], original_mask, True
1405
+
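A minimal illustrative sketch (toy NumPy values, not the app's real inputs) of the inversion used above: pixels inside the drawn mask become background and vice versa, and the result is rescaled to 0/255 before display.

```python
import numpy as np

mask = np.array([[0, 255],
                 [0,   0]], dtype=np.uint8)      # toy single-channel mask
inverted = 1 - (mask > 0).astype(np.uint8)       # 1 wherever the mask was empty
print(inverted)          # [[1 0]
                         #  [1 1]]
print(inverted * 255)    # rescaled to 0/255, as done above before building the PIL image
```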
1406
+
1407
+ def init_img(base,
1408
+ init_type,
1409
+ prompt,
1410
+ aspect_ratio,
1411
+ example_change_times
1412
+ ):
1413
+ image_pil = base["background"].convert("RGB")
1414
+ original_image = np.array(image_pil)
1415
+ if max(original_image.shape[0], original_image.shape[1]) * 1.0 / min(original_image.shape[0], original_image.shape[1])>2.0:
1416
+ raise gr.Error('image aspect ratio cannot be larger than 2.0')
1417
+ if init_type in MASK_IMAGE_PATH.keys() and example_change_times < 2:
1418
+ mask_gallery = [Image.open(MASK_IMAGE_PATH[init_type]).convert("L")]
1419
+ masked_gallery = [Image.open(MASKED_IMAGE_PATH[init_type]).convert("RGB")]
1420
+ result_gallery = [Image.open(OUTPUT_IMAGE_PATH[init_type]).convert("RGB")]
1421
+ width, height = image_pil.size
1422
+ image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
1423
+ height_new, width_new = image_processor.get_default_height_width(image_pil, height, width)
1424
+ image_pil = image_pil.resize((width_new, height_new))
1425
+ mask_gallery[0] = mask_gallery[0].resize((width_new, height_new))
1426
+ masked_gallery[0] = masked_gallery[0].resize((width_new, height_new))
1427
+ result_gallery[0] = result_gallery[0].resize((width_new, height_new))
1428
+ original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1
1429
+ return base, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, "", "", "Custom resolution", False, False, example_change_times
1430
+ else:
1431
+ if aspect_ratio not in ASPECT_RATIO_LABELS:
1432
+ aspect_ratio = "Custom resolution"
1433
+ return base, original_image, None, "", None, None, None, "", "", aspect_ratio, True, False, 0
1434
+
1435
+
1436
+ def reset_func(input_image,
1437
+ original_image,
1438
+ original_mask,
1439
+ prompt,
1440
+ target_prompt,
1441
+ ):
1442
+ input_image = None
1443
+ original_image = None
1444
+ original_mask = None
1445
+ prompt = ''
1446
+ mask_gallery = []
1447
+ masked_gallery = []
1448
+ result_gallery = []
1449
+ target_prompt = ''
1450
+ if torch.cuda.is_available():
1451
+ torch.cuda.empty_cache()
1452
+ return input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, True, False
1453
+
1454
+
1455
+ def update_example(example_type,
1456
+ prompt,
1457
+ example_change_times):
1458
+ input_image = INPUT_IMAGE_PATH[example_type]
1459
+ image_pil = Image.open(input_image).convert("RGB")
1460
+ mask_gallery = [Image.open(MASK_IMAGE_PATH[example_type]).convert("L")]
1461
+ masked_gallery = [Image.open(MASKED_IMAGE_PATH[example_type]).convert("RGB")]
1462
+ result_gallery = [Image.open(OUTPUT_IMAGE_PATH[example_type]).convert("RGB")]
1463
+ width, height = image_pil.size
1464
+ image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
1465
+ height_new, width_new = image_processor.get_default_height_width(image_pil, height, width)
1466
+ image_pil = image_pil.resize((width_new, height_new))
1467
+ mask_gallery[0] = mask_gallery[0].resize((width_new, height_new))
1468
+ masked_gallery[0] = masked_gallery[0].resize((width_new, height_new))
1469
+ result_gallery[0] = result_gallery[0].resize((width_new, height_new))
1470
+
1471
+ original_image = np.array(image_pil)
1472
+ original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1
1473
+ aspect_ratio = "Custom resolution"
1474
+ example_change_times += 1
1475
+ return input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, "", False, example_change_times
1476
+
1477
+
1478
+ block = gr.Blocks(
1479
+ theme=gr.themes.Soft(
1480
+ radius_size=gr.themes.sizes.radius_none,
1481
+ text_size=gr.themes.sizes.text_md
1482
+ )
1483
+ )
1484
+ with block as demo:
1485
+ with gr.Row():
1486
+ with gr.Column():
1487
+ gr.HTML(head)
1488
+
1489
+ gr.Markdown(descriptions)
1490
+
1491
+ with gr.Accordion(label="🧭 Instructions:", open=True, elem_id="accordion"):
1492
+ with gr.Row(equal_height=True):
1493
+ gr.Markdown(instructions)
1494
+
1495
+ original_image = gr.State(value=None)
1496
+ original_mask = gr.State(value=None)
1497
+ category = gr.State(value=None)
1498
+ status = gr.State(value=None)
1499
+ invert_mask_state = gr.State(value=False)
1500
+ example_change_times = gr.State(value=0)
1501
+
1502
+
1503
+ with gr.Row():
1504
+ with gr.Column():
1505
+ with gr.Row():
1506
+ input_image = gr.ImageEditor(
1507
+ label="Input Image",
1508
+ type="pil",
1509
+ brush=gr.Brush(colors=["#FFFFFF"], default_size = 30, color_mode="fixed"),
1510
+ layers = False,
1511
+ interactive=True,
1512
+ height=1024,
1513
+ sources=["upload"],
1514
+ placeholder="Please click here or the icon below to upload the image.",
1515
+ )
1516
+
1517
+ prompt = gr.Textbox(label="⌨️ Instruction", placeholder="Please input your instruction.", value="",lines=1)
1518
+ run_button = gr.Button("💫 Run")
1519
+
1520
+ vlm_model_dropdown = gr.Dropdown(label="VLM model", choices=VLM_MODEL_NAMES, value=DEFAULT_VLM_MODEL_NAME, interactive=True)
1521
+ with gr.Group():
1522
+ with gr.Row():
1523
+ # GPT4o_KEY = gr.Textbox(label="GPT4o API Key", placeholder="Please input your GPT4o API Key when use GPT4o VLM (highly recommended).", value="", lines=1)
1524
+ # GPT4o_KEY = gr.Textbox(type="password", value="sk-d145b963a92649a88843caeb741e8bbc")
1525
+ GPT4o_KEY = gr.Textbox(label="GPT4o API Key", value="sk-d145b963a92649a88843caeb741e8bbc", lines=1)
1526
+
1527
+ GPT4o_KEY_submit = gr.Button("Submit and Verify")
1528
+
1529
+
1530
+ aspect_ratio = gr.Dropdown(label="Output aspect ratio", choices=ASPECT_RATIO_LABELS, value=DEFAULT_ASPECT_RATIO)
1531
+ resize_default = gr.Checkbox(label="Short edge resize to 640px", value=True)
1532
+
1533
+ with gr.Row():
1534
+ mask_button = gr.Button("Generate Mask")
1535
+ random_mask_button = gr.Button("Square/Circle Mask")
1536
+
1537
+
1538
+ with gr.Row():
1539
+ generate_target_prompt_button = gr.Button("Generate Target Prompt")
1540
+
1541
+ target_prompt = gr.Text(
1542
+ label="Input Target Prompt",
1543
+ max_lines=5,
1544
+ placeholder="VLM-generated target prompt, you can first generate if and then modify it (optional)",
1545
+ value='',
1546
+ lines=2
1547
+ )
1548
+
1549
+ with gr.Accordion("Advanced Options", open=False, elem_id="accordion1"):
1550
+ base_model_dropdown = gr.Dropdown(label="Base model", choices=BASE_MODELS, value=DEFAULT_BASE_MODEL, interactive=True)
1551
+ negative_prompt = gr.Text(
1552
+ label="Negative Prompt",
1553
+ max_lines=5,
1554
+ placeholder="Please input your negative prompt",
1555
+ value='ugly, low quality',lines=1
1556
+ )
1557
+
1558
+ control_strength = gr.Slider(
1559
+ label="Control Strength: ", show_label=True, minimum=0, maximum=1.1, value=1, step=0.01
1560
+ )
1561
+ with gr.Group():
1562
+ seed = gr.Slider(
1563
+ label="Seed: ", minimum=0, maximum=2147483647, step=1, value=648464818
1564
+ )
1565
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
1566
+
1567
+ blending = gr.Checkbox(label="Blending mode", value=True)
1568
+
1569
+
1570
+ num_samples = gr.Slider(
1571
+ label="Num samples", minimum=0, maximum=4, step=1, value=4
1572
+ )
1573
+
1574
+ with gr.Group():
1575
+ with gr.Row():
1576
+ guidance_scale = gr.Slider(
1577
+ label="Guidance scale",
1578
+ minimum=1,
1579
+ maximum=12,
1580
+ step=0.1,
1581
+ value=7.5,
1582
+ )
1583
+ num_inference_steps = gr.Slider(
1584
+ label="Number of inference steps",
1585
+ minimum=1,
1586
+ maximum=50,
1587
+ step=1,
1588
+ value=50,
1589
+ )
1590
+
1591
+
1592
+ with gr.Column():
1593
+ with gr.Row():
1594
+ with gr.Tab(elem_classes="feedback", label="Masked Image"):
1595
+ masked_gallery = gr.Gallery(label='Masked Image', show_label=True, elem_id="gallery", preview=True, height=360)
1596
+ with gr.Tab(elem_classes="feedback", label="Mask"):
1597
+ mask_gallery = gr.Gallery(label='Mask', show_label=True, elem_id="gallery", preview=True, height=360)
1598
+
1599
+ invert_mask_button = gr.Button("Invert Mask")
1600
+ dilation_size = gr.Slider(
1601
+ label="Dilation size: ", minimum=0, maximum=50, step=1, value=20
1602
+ )
1603
+ with gr.Row():
1604
+ dilation_mask_button = gr.Button("Dilate Generated Mask")
1605
+ erosion_mask_button = gr.Button("Erode Generated Mask")
1606
+
1607
+ moving_pixels = gr.Slider(
1608
+ label="Moving pixels:", show_label=True, minimum=0, maximum=50, value=4, step=1
1609
+ )
1610
+ with gr.Row():
1611
+ move_left_button = gr.Button("Move Left")
1612
+ move_right_button = gr.Button("Move Right")
1613
+ with gr.Row():
1614
+ move_up_button = gr.Button("Move Up")
1615
+ move_down_button = gr.Button("Move Down")
1616
+
1617
+ with gr.Tab(elem_classes="feedback", label="Output"):
1618
+ result_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", preview=True, height=400)
1619
+
1620
+ # target_prompt_output = gr.Text(label="Output Target Prompt", value="", lines=1, interactive=False)
1621
+
1622
+ reset_button = gr.Button("Reset")
1623
+
1624
+ init_type = gr.Textbox(label="Init Name", value="", visible=False)
1625
+ example_type = gr.Textbox(label="Example Name", value="", visible=False)
1626
+
1627
+
1628
+
1629
+ with gr.Row():
1630
+ example = gr.Examples(
1631
+ label="Quick Example",
1632
+ examples=EXAMPLES,
1633
+ inputs=[input_image, prompt, seed, init_type, example_type, blending, resize_default, vlm_model_dropdown],
1634
+ examples_per_page=10,
1635
+ cache_examples=False,
1636
+ )
1637
+
1638
+
1639
+ with gr.Accordion(label="🎬 Feature Details:", open=True, elem_id="accordion"):
1640
+ with gr.Row(equal_height=True):
1641
+ gr.Markdown(tips)
1642
+
1643
+ with gr.Row():
1644
+ gr.Markdown(citation)
1645
+
1646
+ ## gr.Examples cannot be used to update gr.Gallery directly, so we need the following two functions to update the galleries.
1647
+ ## They also resolve the conflict between the upload handler and the example-change handler.
1648
+ input_image.upload(
1649
+ init_img,
1650
+ [input_image, init_type, prompt, aspect_ratio, example_change_times],
1651
+ [input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, init_type, aspect_ratio, resize_default, invert_mask_state, example_change_times]
1652
+ )
1653
+ example_type.change(fn=update_example, inputs=[example_type, prompt, example_change_times], outputs=[input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, target_prompt, invert_mask_state, example_change_times])
1654
+
1655
+ ## vlm and base model dropdown
1656
+ vlm_model_dropdown.change(fn=update_vlm_model, inputs=[vlm_model_dropdown], outputs=[status])
1657
+ base_model_dropdown.change(fn=update_base_model, inputs=[base_model_dropdown], outputs=[status])
1658
+
1659
+
1660
+ GPT4o_KEY_submit.click(fn=submit_GPT4o_KEY, inputs=[GPT4o_KEY], outputs=[GPT4o_KEY, vlm_model_dropdown])
1661
+ invert_mask_button.click(fn=invert_mask, inputs=[input_image, original_image, original_mask], outputs=[masked_gallery, mask_gallery, original_mask, invert_mask_state])
1662
+
1663
+
1664
+ ips=[input_image,
1665
+ original_image,
1666
+ original_mask,
1667
+ prompt,
1668
+ negative_prompt,
1669
+ control_strength,
1670
+ seed,
1671
+ randomize_seed,
1672
+ guidance_scale,
1673
+ num_inference_steps,
1674
+ num_samples,
1675
+ blending,
1676
+ category,
1677
+ target_prompt,
1678
+ resize_default,
1679
+ aspect_ratio,
1680
+ invert_mask_state]
1681
+
1682
+ ## run brushedit
1683
+ run_button.click(fn=process, inputs=ips, outputs=[result_gallery, mask_gallery, masked_gallery, prompt, target_prompt, invert_mask_state])
1684
+
1685
+ ## mask func
1686
+ mask_button.click(fn=process_mask, inputs=[input_image, original_image, prompt, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask, category])
1687
+ random_mask_button.click(fn=process_random_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
1688
+ dilation_mask_button.click(fn=process_dilation_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[ masked_gallery, mask_gallery, original_mask])
1689
+ erosion_mask_button.click(fn=process_erosion_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[ masked_gallery, mask_gallery, original_mask])
1690
+
1691
+ ## move mask func
1692
+ move_left_button.click(fn=move_mask_left, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
1693
+ move_right_button.click(fn=move_mask_right, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
1694
+ move_up_button.click(fn=move_mask_up, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
1695
+ move_down_button.click(fn=move_mask_down, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
1696
+
1697
+ ## prompt func
1698
+ generate_target_prompt_button.click(fn=generate_target_prompt, inputs=[input_image, original_image, prompt], outputs=[target_prompt])
1699
+
1700
+ ## reset func
1701
+ reset_button.click(fn=reset_func, inputs=[input_image, original_image, original_mask, prompt, target_prompt], outputs=[input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, resize_default, invert_mask_state])
1702
+
1703
+ # if you have a localhost access error, try using the following launch configuration
1704
+ demo.launch(server_name="0.0.0.0", server_port=12345, share=True)
1705
+ # demo.launch()
brushedit_app_315_0.py ADDED
@@ -0,0 +1,1696 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import os, random, sys
4
+ import numpy as np
5
+ import requests
6
+ import torch
7
+
8
+
9
+ import gradio as gr
10
+
11
+ from PIL import Image
12
+
13
+
14
+ from huggingface_hub import hf_hub_download, snapshot_download
15
+ from scipy.ndimage import binary_dilation, binary_erosion
16
+ from transformers import (LlavaNextProcessor, LlavaNextForConditionalGeneration,
17
+ Qwen2VLForConditionalGeneration, Qwen2VLProcessor)
18
+
19
+ from segment_anything import SamPredictor, build_sam, SamAutomaticMaskGenerator
20
+ from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler
21
+ from diffusers.image_processor import VaeImageProcessor
22
+
23
+
24
+ from app.src.vlm_pipeline import (
25
+ vlm_response_editing_type,
26
+ vlm_response_object_wait_for_edit,
27
+ vlm_response_mask,
28
+ vlm_response_prompt_after_apply_instruction
29
+ )
30
+ from app.src.brushedit_all_in_one_pipeline import BrushEdit_Pipeline
31
+ from app.utils.utils import load_grounding_dino_model
32
+
33
+ from app.src.vlm_template import vlms_template
34
+ from app.src.base_model_template import base_models_template
35
+ from app.src.aspect_ratio_template import aspect_ratios
36
+
37
+ from openai import OpenAI
38
+ # base_openai_url = ""
39
+
40
+ #### Description ####
41
+ logo = r"""
42
+ <center><img src='./assets/logo_brushedit.png' alt='BrushEdit logo' style="width:80px; margin-bottom:10px"></center>
43
+ """
44
+ head = r"""
45
+ <div style="text-align: center;">
46
+ <h1> BrushEdit: All-In-One Image Inpainting and Editing</h1>
47
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
48
+ <a href='https://liyaowei-stu.github.io/project/BrushEdit/'><img src='https://img.shields.io/badge/Project_Page-BrushEdit-green' alt='Project Page'></a>
49
+ <a href='https://arxiv.org/abs/2412.10316'><img src='https://img.shields.io/badge/Paper-Arxiv-blue'></a>
50
+ <a href='https://github.com/TencentARC/BrushEdit'><img src='https://img.shields.io/badge/Code-Github-orange'></a>
51
+
52
+ </div>
53
+ </br>
54
+ </div>
55
+ """
56
+ descriptions = r"""
57
+ Official Gradio Demo for <a href='https://tencentarc.github.io/BrushNet/'><b>BrushEdit: All-In-One Image Inpainting and Editing</b></a><br>
58
+ 🧙 BrushEdit enables precise, user-friendly instruction-based image editing via an inpainting model.<br>
59
+ """
60
+
61
+ instructions = r"""
62
+ Currently, we support two modes: <b>fully automated instruction-based editing</b> and <b>interactive instruction-based editing</b>.
63
+
64
+ 🛠️ <b>Fully automated instruction-based editing</b>:
65
+ <ul>
66
+ <li> ⭐️ <b>1.Choose Image: </b> Upload <img src="https://github.com/user-attachments/assets/f2dca1e6-31f9-4716-ae84-907f24415bac" alt="upload" style="display:inline; height:1em; vertical-align:middle;"> or select <img src="https://github.com/user-attachments/assets/de808f7d-c74a-44c7-9cbf-f0dbfc2c1abf" alt="example" style="display:inline; height:1em; vertical-align:middle;"> one image from Example. </li>
67
+ <li> ⭐️ <b>2.Input ⌨️ Instructions: </b> Input the instructions (addition, deletion, and modification are supported), e.g., remove xxx.</li>
68
+ <li> ⭐️ <b>3.Run: </b> Click the <b>💫 Run</b> button to automatically edit the image.</li>
69
+ </ul>
70
+
71
+ 🛠️ <b>Interactive instruction-based editing</b>:
72
+ <ul>
73
+ <li> ⭐️ <b>1.Choose Image: </b> Upload <img src="https://github.com/user-attachments/assets/f2dca1e6-31f9-4716-ae84-907f24415bac" alt="upload" style="display:inline; height:1em; vertical-align:middle;"> or select <img src="https://github.com/user-attachments/assets/de808f7d-c74a-44c7-9cbf-f0dbfc2c1abf" alt="example" style="display:inline; height:1em; vertical-align:middle;"> one image from Example. </li>
74
+ <li> ⭐️ <b>2.Fine Brushing: </b> Use a brush <img src="https://github.com/user-attachments/assets/c466c5cc-ac8f-4b4a-9bc5-04c4737fe1ef" alt="brush" style="display:inline; height:1em; vertical-align:middle;"> to outline the area you want to edit. You can also use the eraser <img src="https://github.com/user-attachments/assets/b6370369-b080-4550-b0d0-830ff22d9068" alt="eraser" style="display:inline; height:1em; vertical-align:middle;"> to restore areas. </li>
75
+ <li> ⭐️ <b>3.Input ⌨️ Instructions: </b> Input the instructions. </li>
76
+ <li> ⭐️ <b>4.Run: </b> Click the <b>💫 Run</b> button to automatically edit the image. </li>
77
+ </ul>
78
+
79
+ <b> We strongly recommend using GPT-4o for reasoning. </b> After selecting GPT4-o as the VLM model, enter your API key and click the Submit and Verify button. If the output is "success", you can use GPT4-o normally. As a second choice, we recommend the Qwen2VL model.
80
+
81
+ <b> We recommend zooming out in your browser for a better viewing range and experience. </b>
82
+
83
+ <b> For more detailed feature descriptions, see the bottom. </b>
84
+
85
+ ☕️ Have fun! 🎄 Wishing you a merry Christmas!
86
+ """
87
+
88
+ tips = r"""
89
+ 💡 <b>Some Tips</b>:
90
+ <ul>
91
+ <li> 🤠 After inputting the instructions, you can click the <b>Generate Mask</b> button. The mask generated by the VLM will be displayed in the preview panel on the right side. </li>
92
+ <li> 🤠 After generating the mask or when you use the brush to draw the mask, you can perform operations such as <b>randomization</b>, <b>dilation</b>, <b>erosion</b>, and <b>movement</b>. </li>
93
+ <li> 🤠 After inputting the instructions, you can click the <b>Generate Target Prompt</b> button. The target prompt will be displayed in the text box, and you can modify it according to your ideas. </li>
94
+ </ul>
95
+
96
+ 💡 <b>Detailed Features</b>:
97
+ <ul>
98
+ <li> 🎨 <b>Aspect Ratio</b>: Select the aspect ratio of the image. To prevent OOM, 1024px is the maximum resolution.</li>
99
+ <li> 🎨 <b>VLM Model</b>: Select the VLM model. We use preloaded models to save time. To use other VLM models, download them and uncomment the relevant lines in vlm_template.py from our GitHub repo. </li>
100
+ <li> 🎨 <b>Generate Mask</b>: According to the input instructions, generate a mask for the area that may need to be edited. </li>
101
+ <li> 🎨 <b>Square/Circle Mask</b>: Based on the existing mask, generate masks for squares and circles. (The coarse-grained mask provides more editing imagination.) </li>
102
+ <li> 🎨 <b>Invert Mask</b>: Invert the mask to generate a new mask. </li>
103
+ <li> 🎨 <b>Dilation/Erosion Mask</b>: Expand or shrink the mask to include or exclude more areas. </li>
104
+ <li> 🎨 <b>Move Mask</b>: Move the mask to a new position. </li>
105
+ <li> 🎨 <b>Generate Target Prompt</b>: Generate a target prompt based on the input instructions. </li>
106
+ <li> 🎨 <b>Target Prompt</b>: Description for masking area, manual input or modification can be made when the content generated by VLM does not meet expectations. </li>
107
+ <li> 🎨 <b>Blending</b>: Blend BrushNet's output with the original input, preserving the original image details in the unedited areas. (Turning it off is better for removal.) </li>
108
+ <li> 🎨 <b>Control strength</b>: The intensity of editing and inpainting. </li>
109
+ </ul>
110
+
111
+ 💡 <b>Advanced Features</b>:
112
+ <ul>
113
+ <li> 🎨 <b>Base Model</b>: We use preloaded models to save time. To use other base models, download them and uncomment the relevant lines in base_model_template.py from our GitHub repo. </li>
114
+ <li> 🎨 <b>Blending</b>: Blend BrushNet's output with the original input, preserving the original image details in the unedited areas. (Turning it off is better for removal.) </li>
115
+ <li> 🎨 <b>Control strength</b>: The intensity of editing and inpainting. </li>
116
+ <li> 🎨 <b>Num samples</b>: The number of samples to generate. </li>
117
+ <li> 🎨 <b>Negative prompt</b>: The negative prompt for the classifier-free guidance. </li>
118
+ <li> 🎨 <b>Guidance scale</b>: The guidance scale for the classifier-free guidance. </li>
119
+ </ul>
120
+
121
+
122
+ """
123
+
124
+
125
+
126
+ citation = r"""
127
+ If BrushEdit is helpful, please help to ⭐ the <a href='https://github.com/TencentARC/BrushEdit' target='_blank'>Github Repo</a>. Thanks!
128
+ [![GitHub Stars](https://img.shields.io/github/stars/TencentARC/BrushEdit?style=social)](https://github.com/TencentARC/BrushEdit)
129
+ ---
130
+ 📝 **Citation**
131
+ <br>
132
+ If our work is useful for your research, please consider citing:
133
+ ```bibtex
134
+ @misc{li2024brushedit,
135
+ title={BrushEdit: All-In-One Image Inpainting and Editing},
136
+ author={Yaowei Li and Yuxuan Bian and Xuan Ju and Zhaoyang Zhang and Junhao Zhuang and Ying Shan and Yuexian Zou and Qiang Xu},
137
+ year={2024},
138
+ eprint={2412.10316},
139
+ archivePrefix={arXiv},
140
+ primaryClass={cs.CV}
141
+ }
142
+ ```
143
+ 📧 **Contact**
144
+ <br>
145
+ If you have any questions, please feel free to reach out to me at <b>liyaowei@gmail.com</b>.
146
+ """
147
+
148
+ # - - - - - examples - - - - - #
149
+ EXAMPLES = [
150
+
151
+ [
152
+ Image.open("./assets/frog/frog.jpeg").convert("RGBA"),
153
+ "add a magic hat on frog head.",
154
+ 642087011,
155
+ "frog",
156
+ "frog",
157
+ True,
158
+ False,
159
+ "GPT4-o (Highly Recommended)"
160
+ ],
161
+ [
162
+ Image.open("./assets/chinese_girl/chinese_girl.png").convert("RGBA"),
163
+ "replace the background to ancient China.",
164
+ 648464818,
165
+ "chinese_girl",
166
+ "chinese_girl",
167
+ True,
168
+ False,
169
+ "GPT4-o (Highly Recommended)"
170
+ ],
171
+ [
172
+ Image.open("./assets/angel_christmas/angel_christmas.png").convert("RGBA"),
173
+ "remove the deer.",
174
+ 648464818,
175
+ "angel_christmas",
176
+ "angel_christmas",
177
+ False,
178
+ False,
179
+ "GPT4-o (Highly Recommended)"
180
+ ],
181
+ [
182
+ Image.open("./assets/sunflower_girl/sunflower_girl.png").convert("RGBA"),
183
+ "add a wreath on head.",
184
+ 648464818,
185
+ "sunflower_girl",
186
+ "sunflower_girl",
187
+ True,
188
+ False,
189
+ "GPT4-o (Highly Recommended)"
190
+ ],
191
+ [
192
+ Image.open("./assets/girl_on_sun/girl_on_sun.png").convert("RGBA"),
193
+ "add a butterfly fairy.",
194
+ 648464818,
195
+ "girl_on_sun",
196
+ "girl_on_sun",
197
+ True,
198
+ False,
199
+ "GPT4-o (Highly Recommended)"
200
+ ],
201
+ [
202
+ Image.open("./assets/spider_man_rm/spider_man.png").convert("RGBA"),
203
+ "remove the christmas hat.",
204
+ 642087011,
205
+ "spider_man_rm",
206
+ "spider_man_rm",
207
+ False,
208
+ False,
209
+ "GPT4-o (Highly Recommended)"
210
+ ],
211
+ [
212
+ Image.open("./assets/anime_flower/anime_flower.png").convert("RGBA"),
213
+ "remove the flower.",
214
+ 642087011,
215
+ "anime_flower",
216
+ "anime_flower",
217
+ False,
218
+ False,
219
+ "GPT4-o (Highly Recommended)"
220
+ ],
221
+ [
222
+ Image.open("./assets/chenduling/chengduling.jpg").convert("RGBA"),
223
+ "replace the clothes to a delicated floral skirt.",
224
+ 648464818,
225
+ "chenduling",
226
+ "chenduling",
227
+ True,
228
+ False,
229
+ "GPT4-o (Highly Recommended)"
230
+ ],
231
+ [
232
+ Image.open("./assets/hedgehog_rp_bg/hedgehog.png").convert("RGBA"),
233
+ "make the hedgehog in Italy.",
234
+ 648464818,
235
+ "hedgehog_rp_bg",
236
+ "hedgehog_rp_bg",
237
+ True,
238
+ False,
239
+ "GPT4-o (Highly Recommended)"
240
+ ],
241
+
242
+ ]
243
+
244
+ INPUT_IMAGE_PATH = {
245
+ "frog": "./assets/frog/frog.jpeg",
246
+ "chinese_girl": "./assets/chinese_girl/chinese_girl.png",
247
+ "angel_christmas": "./assets/angel_christmas/angel_christmas.png",
248
+ "sunflower_girl": "./assets/sunflower_girl/sunflower_girl.png",
249
+ "girl_on_sun": "./assets/girl_on_sun/girl_on_sun.png",
250
+ "spider_man_rm": "./assets/spider_man_rm/spider_man.png",
251
+ "anime_flower": "./assets/anime_flower/anime_flower.png",
252
+ "chenduling": "./assets/chenduling/chengduling.jpg",
253
+ "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/hedgehog.png",
254
+ }
255
+ MASK_IMAGE_PATH = {
256
+ "frog": "./assets/frog/mask_f7b350de-6f2c-49e3-b535-995c486d78e7.png",
257
+ "chinese_girl": "./assets/chinese_girl/mask_54759648-0989-48e0-bc82-f20e28b5ec29.png",
258
+ "angel_christmas": "./assets/angel_christmas/mask_f15d9b45-c978-4e3d-9f5f-251e308560c3.png",
259
+ "sunflower_girl": "./assets/sunflower_girl/mask_99cc50b4-7dc4-4de5-8748-ec10772f0317.png",
260
+ "girl_on_sun": "./assets/girl_on_sun/mask_264eac8b-8b65-479c-9755-020a60880c37.png",
261
+ "spider_man_rm": "./assets/spider_man_rm/mask_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png",
262
+ "anime_flower": "./assets/anime_flower/mask_37553172-9b38-4727-bf2e-37d7e2b93461.png",
263
+ "chenduling": "./assets/chenduling/mask_68e3ff6f-da07-4b37-91df-13d6eed7b997.png",
264
+ "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/mask_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png",
265
+ }
266
+ MASKED_IMAGE_PATH = {
267
+ "frog": "./assets/frog/masked_image_f7b350de-6f2c-49e3-b535-995c486d78e7.png",
268
+ "chinese_girl": "./assets/chinese_girl/masked_image_54759648-0989-48e0-bc82-f20e28b5ec29.png",
269
+ "angel_christmas": "./assets/angel_christmas/masked_image_f15d9b45-c978-4e3d-9f5f-251e308560c3.png",
270
+ "sunflower_girl": "./assets/sunflower_girl/masked_image_99cc50b4-7dc4-4de5-8748-ec10772f0317.png",
271
+ "girl_on_sun": "./assets/girl_on_sun/masked_image_264eac8b-8b65-479c-9755-020a60880c37.png",
272
+ "spider_man_rm": "./assets/spider_man_rm/masked_image_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png",
273
+ "anime_flower": "./assets/anime_flower/masked_image_37553172-9b38-4727-bf2e-37d7e2b93461.png",
274
+ "chenduling": "./assets/chenduling/masked_image_68e3ff6f-da07-4b37-91df-13d6eed7b997.png",
275
+ "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/masked_image_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png",
276
+ }
277
+ OUTPUT_IMAGE_PATH = {
278
+ "frog": "./assets/frog/image_edit_f7b350de-6f2c-49e3-b535-995c486d78e7_1.png",
279
+ "chinese_girl": "./assets/chinese_girl/image_edit_54759648-0989-48e0-bc82-f20e28b5ec29_1.png",
280
+ "angel_christmas": "./assets/angel_christmas/image_edit_f15d9b45-c978-4e3d-9f5f-251e308560c3_0.png",
281
+ "sunflower_girl": "./assets/sunflower_girl/image_edit_99cc50b4-7dc4-4de5-8748-ec10772f0317_3.png",
282
+ "girl_on_sun": "./assets/girl_on_sun/image_edit_264eac8b-8b65-479c-9755-020a60880c37_0.png",
283
+ "spider_man_rm": "./assets/spider_man_rm/image_edit_a5d410e6-8e8d-432f-8144-defbc3e1eae9_0.png",
284
+ "anime_flower": "./assets/anime_flower/image_edit_37553172-9b38-4727-bf2e-37d7e2b93461_2.png",
285
+ "chenduling": "./assets/chenduling/image_edit_68e3ff6f-da07-4b37-91df-13d6eed7b997_0.png",
286
+ "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/image_edit_db7f8bf8-8349-46d3-b14e-43d67fbe25d3_3.png",
287
+ }
288
+
289
+ # os.environ['GRADIO_TEMP_DIR'] = 'gradio_temp_dir'
290
+ # os.makedirs('gradio_temp_dir', exist_ok=True)
291
+
292
+ VLM_MODEL_NAMES = list(vlms_template.keys())
293
+ DEFAULT_VLM_MODEL_NAME = "Qwen2-VL-7B-Instruct (Default)"
294
+ BASE_MODELS = list(base_models_template.keys())
295
+ DEFAULT_BASE_MODEL = "realisticVision (Default)"
296
+
297
+ ASPECT_RATIO_LABELS = list(aspect_ratios)
298
+ DEFAULT_ASPECT_RATIO = ASPECT_RATIO_LABELS[0]
299
+
300
+ device = "cuda"
301
+ torch_dtype = torch.bfloat16
302
+
303
+ ## init device
304
+ # try:
305
+ # if torch.cuda.is_available():
306
+ # device = "cuda"
307
+ # print("device = cuda")
308
+ # elif sys.platform == "darwin" and torch.backends.mps.is_available():
309
+ # device = "mps"
310
+ # print("device = mps")
311
+ # else:
312
+ # device = "cpu"
313
+ # print("device = cpu")
314
+ # except:
315
+ # device = "cpu"
316
+
317
+
318
+
319
+ # download hf models
320
+ BrushEdit_path = "models/"
321
+ if not os.path.exists(BrushEdit_path):
322
+ BrushEdit_path = snapshot_download(
323
+ repo_id="TencentARC/BrushEdit",
324
+ local_dir=BrushEdit_path,
325
+ token=os.getenv("HF_TOKEN"),
326
+ )
327
+
328
+ ## init default VLM
329
+ vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[DEFAULT_VLM_MODEL_NAME]
330
+ if vlm_processor != "" and vlm_model != "":
331
+ vlm_model.to(device)
332
+ else:
333
+ raise gr.Error("Please Download default VLM model "+ DEFAULT_VLM_MODEL_NAME +" first.")
334
+
335
+
336
+ ## init base model
337
+ base_model_path = os.path.join(BrushEdit_path, "base_model/realisticVisionV60B1_v51VAE")
338
+ brushnet_path = os.path.join(BrushEdit_path, "brushnetX")
339
+ sam_path = os.path.join(BrushEdit_path, "sam/sam_vit_h_4b8939.pth")
340
+ groundingdino_path = os.path.join(BrushEdit_path, "grounding_dino/groundingdino_swint_ogc.pth")
341
+
342
+
343
+ # input brushnetX ckpt path
344
+ brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch_dtype)
345
+ pipe = StableDiffusionBrushNetPipeline.from_pretrained(
346
+ base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
347
+ )
348
+ # speed up diffusion process with faster scheduler and memory optimization
349
+ pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
350
+ # remove following line if xformers is not installed or when using Torch 2.0.
351
+ # pipe.enable_xformers_memory_efficient_attention()
352
+ pipe.enable_model_cpu_offload()
353
+
354
+
355
+ ## init SAM
356
+ sam = build_sam(checkpoint=sam_path)
357
+ sam.to(device=device)
358
+ sam_predictor = SamPredictor(sam)
359
+ sam_automask_generator = SamAutomaticMaskGenerator(sam)
360
+
361
+ ## init groundingdino_model
362
+ config_file = 'app/utils/GroundingDINO_SwinT_OGC.py'
363
+ groundingdino_model = load_grounding_dino_model(config_file, groundingdino_path, device=device)
364
+
365
+ ## Ordinary function
366
+ def crop_and_resize(image: Image.Image,
367
+ target_width: int,
368
+ target_height: int) -> Image.Image:
369
+ """
370
+ Crops and resizes an image while preserving the aspect ratio.
371
+
372
+ Args:
373
+ image (Image.Image): Input PIL image to be cropped and resized.
374
+ target_width (int): Target width of the output image.
375
+ target_height (int): Target height of the output image.
376
+
377
+ Returns:
378
+ Image.Image: Cropped and resized image.
379
+ """
380
+ # Original dimensions
381
+ original_width, original_height = image.size
382
+ original_aspect = original_width / original_height
383
+ target_aspect = target_width / target_height
384
+
385
+ # Calculate crop box to maintain aspect ratio
386
+ if original_aspect > target_aspect:
387
+ # Crop horizontally
388
+ new_width = int(original_height * target_aspect)
389
+ new_height = original_height
390
+ left = (original_width - new_width) / 2
391
+ top = 0
392
+ right = left + new_width
393
+ bottom = original_height
394
+ else:
395
+ # Crop vertically
396
+ new_width = original_width
397
+ new_height = int(original_width / target_aspect)
398
+ left = 0
399
+ top = (original_height - new_height) / 2
400
+ right = original_width
401
+ bottom = top + new_height
402
+
403
+ # Crop and resize
404
+ cropped_image = image.crop((left, top, right, bottom))
405
+ resized_image = cropped_image.resize((target_width, target_height), Image.NEAREST)
406
+ return resized_image
407
+
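A short usage sketch of the center-crop logic above (toy sizes; `target_w`/`target_h` are assumed values, and the `crop_and_resize` helper defined above is assumed to be in scope): when the source is wider than the target aspect ratio, only the width is cropped before resizing.

```python
from PIL import Image

src = Image.new("RGB", (800, 400))        # toy 2:1 source image
target_w, target_h = 512, 512             # assumed 1:1 target
out = crop_and_resize(src, target_w, target_h)
# the crop box keeps the central 400x400 region, then resizes it to 512x512
assert out.size == (512, 512)
```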
408
+
409
+ ## Ordinary function
410
+ def resize(image: Image.Image,
411
+ target_width: int,
412
+ target_height: int) -> Image.Image:
413
+ """
414
+ Crops and resizes an image while preserving the aspect ratio.
415
+
416
+ Args:
417
+ image (Image.Image): Input PIL image to be cropped and resized.
418
+ target_width (int): Target width of the output image.
419
+ target_height (int): Target height of the output image.
420
+
421
+ Returns:
422
+ Image.Image: Cropped and resized image.
423
+ """
424
+ # Original dimensions
425
+ resized_image = image.resize((target_width, target_height), Image.NEAREST)
426
+ return resized_image
427
+
428
+
429
+ def move_mask_func(mask, direction, units):
430
+ binary_mask = mask.squeeze()>0
431
+ rows, cols = binary_mask.shape
432
+ moved_mask = np.zeros_like(binary_mask, dtype=bool)
433
+
434
+ if direction == 'down':
435
+ # move down
436
+ moved_mask[max(0, units):, :] = binary_mask[:rows - units, :]
437
+
438
+ elif direction == 'up':
439
+ # move up
440
+ moved_mask[:rows - units, :] = binary_mask[units:, :]
441
+
442
+ elif direction == 'right':
443
+ # move right
444
+ moved_mask[:, max(0, units):] = binary_mask[:, :cols - units]
445
+
446
+ elif direction == 'left':
447
+ # move left
448
+ moved_mask[:, :cols - units] = binary_mask[:, units:]
449
+
450
+ return moved_mask
451
+
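A toy example (assumed values) of what the slicing in `move_mask_func` does: the mask is shifted by `units` pixels, and anything pushed past the border is dropped rather than wrapped.

```python
import numpy as np

mask = np.zeros((5, 5, 1), dtype=np.uint8)
mask[2, 2, 0] = 255                      # single marked pixel at row 2, col 2
moved = move_mask_func(mask, 'down', 1)  # boolean array returned by the helper above
assert moved[3, 2] and not moved[2, 2]   # the pixel moved one row down
```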
452
+
453
+ def random_mask_func(mask, dilation_type='square', dilation_size=20):
454
+ # Randomly select the size of dilation
455
+ binary_mask = mask.squeeze()>0
456
+
457
+ if dilation_type == 'square_dilation':
458
+ structure = np.ones((dilation_size, dilation_size), dtype=bool)
459
+ dilated_mask = binary_dilation(binary_mask, structure=structure)
460
+ elif dilation_type == 'square_erosion':
461
+ structure = np.ones((dilation_size, dilation_size), dtype=bool)
462
+ dilated_mask = binary_erosion(binary_mask, structure=structure)
463
+ elif dilation_type == 'bounding_box':
464
+ # find the most left top and left bottom point
465
+ rows, cols = np.where(binary_mask)
466
+ if len(rows) == 0 or len(cols) == 0:
467
+ return mask # return original mask if no valid points
468
+
469
+ min_row = np.min(rows)
470
+ max_row = np.max(rows)
471
+ min_col = np.min(cols)
472
+ max_col = np.max(cols)
473
+
474
+ # create a bounding box
475
+ dilated_mask = np.zeros_like(binary_mask, dtype=bool)
476
+ dilated_mask[min_row:max_row + 1, min_col:max_col + 1] = True
477
+
478
+ elif dilation_type == 'bounding_ellipse':
479
+ # find the most left top and left bottom point
480
+ rows, cols = np.where(binary_mask)
481
+ if len(rows) == 0 or len(cols) == 0:
482
+ return mask # return original mask if no valid points
483
+
484
+ min_row = np.min(rows)
485
+ max_row = np.max(rows)
486
+ min_col = np.min(cols)
487
+ max_col = np.max(cols)
488
+
489
+ # calculate the center and axis length of the ellipse
490
+ center = ((min_col + max_col) // 2, (min_row + max_row) // 2)
491
+ a = (max_col - min_col) // 2 # half long axis
492
+ b = (max_row - min_row) // 2 # half short axis
493
+
494
+ # create a bounding ellipse
495
+ y, x = np.ogrid[:mask.shape[0], :mask.shape[1]]
496
+ ellipse_mask = ((x - center[0])**2 / a**2 + (y - center[1])**2 / b**2) <= 1
497
+ dilated_mask = np.zeros_like(binary_mask, dtype=bool)
498
+ dilated_mask[ellipse_mask] = True
499
+ else:
500
+ ValueError("dilation_type must be 'square' or 'ellipse'")
501
+
502
+ # use binary dilation
503
+ dilated_mask = np.uint8(dilated_mask[:,:,np.newaxis]) * 255
504
+ return dilated_mask
505
+
506
+
507
+ ## Gradio component function
508
+ def update_vlm_model(vlm_name):
509
+ global vlm_model, vlm_processor
510
+ if vlm_model is not None:
511
+ del vlm_model
512
+ torch.cuda.empty_cache()
513
+
514
+ vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[vlm_name]
515
+
516
+ ## we recommend using preloaded models; otherwise downloading the model will take a long time. You can edit the code via vlm_template.py
517
+ if vlm_type == "llava-next":
518
+ if vlm_processor != "" and vlm_model != "":
519
+ vlm_model.to(device)
520
+ return vlm_model_dropdown
521
+ else:
522
+ if os.path.exists(vlm_local_path):
523
+ vlm_processor = LlavaNextProcessor.from_pretrained(vlm_local_path)
524
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto")
525
+ else:
526
+ if vlm_name == "llava-v1.6-mistral-7b-hf (Preload)":
527
+ vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
528
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", torch_dtype="auto", device_map="auto")
529
+ elif vlm_name == "llama3-llava-next-8b-hf (Preload)":
530
+ vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llama3-llava-next-8b-hf")
531
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llama3-llava-next-8b-hf", torch_dtype="auto", device_map="auto")
532
+ elif vlm_name == "llava-v1.6-vicuna-13b-hf (Preload)":
533
+ vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf")
534
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf", torch_dtype="auto", device_map="auto")
535
+ elif vlm_name == "llava-v1.6-34b-hf (Preload)":
536
+ vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-34b-hf")
537
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-34b-hf", torch_dtype="auto", device_map="auto")
538
+ elif vlm_name == "llava-next-72b-hf (Preload)":
539
+ vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-next-72b-hf")
540
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-next-72b-hf", torch_dtype="auto", device_map="auto")
541
+ elif vlm_type == "qwen2-vl":
542
+ if vlm_processor != "" and vlm_model != "":
543
+ vlm_model.to(device)
544
+ return vlm_model_dropdown
545
+ else:
546
+ if os.path.exists(vlm_local_path):
547
+ vlm_processor = Qwen2VLProcessor.from_pretrained(vlm_local_path)
548
+ vlm_model = Qwen2VLForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto")
549
+ else:
550
+ if vlm_name == "qwen2-vl-2b-instruct (Preload)":
551
+ vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
552
+ vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", torch_dtype="auto", device_map="auto")
553
+ elif vlm_name == "qwen2-vl-7b-instruct (Preload)":
554
+ vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
555
+ vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto")
556
+ elif vlm_name == "qwen2-vl-72b-instruct (Preload)":
557
+ vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-72B-Instruct")
558
+ vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-72B-Instruct", torch_dtype="auto", device_map="auto")
559
+ elif vlm_type == "openai":
560
+ pass
561
+ return "success"
562
+
563
+
564
+ def update_base_model(base_model_name):
565
+ global pipe
566
+ ## We recommend using preloaded models; otherwise downloading them can take a long time. You can adjust this behavior in base_model_template.py.
567
+ if pipe is not None:
568
+ del pipe
569
+ torch.cuda.empty_cache()
570
+ base_model_path, pipe = base_models_template[base_model_name]
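+ # base_models_template maps a model name to (local path, preloaded pipeline); an empty pipeline string means it must be loaded from disk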
571
+ if pipe != "":
572
+ pipe.to(device)
573
+ else:
574
+ if os.path.exists(base_model_path):
575
+ pipe = StableDiffusionBrushNetPipeline.from_pretrained(
576
+ base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
577
+ )
578
+ # pipe.enable_xformers_memory_efficient_attention()
579
+ pipe.enable_model_cpu_offload()
580
+ else:
581
+ raise gr.Error(f"The base model {base_model_name} does not exist")
582
+ return "success"
583
+
584
+
585
+ def submit_GPT4o_KEY(GPT4o_KEY):
586
+ global vlm_model, vlm_processor
587
+ if vlm_model is not None:
588
+ del vlm_model
589
+ torch.cuda.empty_cache()
590
+ try:
591
+ vlm_model = OpenAI(api_key=GPT4o_KEY)
592
+ vlm_processor = ""
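+ # send a minimal chat completion request to verify the key before accepting it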
593
+ response = vlm_model.chat.completions.create(
594
+ model="gpt-4o-2024-08-06",
595
+ messages=[
596
+ {"role": "system", "content": "You are a helpful assistant."},
597
+ {"role": "user", "content": "Say this is a test"}
598
+ ]
599
+ )
600
+ response_str = response.choices[0].message.content
601
+
602
+ return "Success, " + response_str, "GPT4-o (Highly Recommended)"
603
+ except Exception as e:
604
+ return "Invalid GPT4o API Key", "GPT4-o (Highly Recommended)"
605
+
606
+
607
+
608
+ def process(input_image,
609
+ original_image,
610
+ original_mask,
611
+ prompt,
612
+ negative_prompt,
613
+ control_strength,
614
+ seed,
615
+ randomize_seed,
616
+ guidance_scale,
617
+ num_inference_steps,
618
+ num_samples,
619
+ blending,
620
+ category,
621
+ target_prompt,
622
+ resize_default,
623
+ aspect_ratio_name,
624
+ invert_mask_state):
625
+ if original_image is None:
626
+ if input_image is None:
627
+ raise gr.Error('Please upload the input image')
628
+ else:
629
+ image_pil = input_image["background"].convert("RGB")
630
+ original_image = np.array(image_pil)
631
+ if prompt is None or prompt == "":
632
+ if target_prompt is None or target_prompt == "":
633
+ raise gr.Error("Please input your instructions, e.g., remove the xxx")
634
+
635
+ alpha_mask = input_image["layers"][0].split()[3]
636
+ input_mask = np.asarray(alpha_mask)
637
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
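+ # "Custom resolution" stores empty strings, so fall back to the original image size in that case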
638
+ if output_w == "" or output_h == "":
639
+ output_h, output_w = original_image.shape[:2]
640
+
641
+ if resize_default:
642
+ short_side = min(output_w, output_h)
643
+ scale_ratio = 640 / short_side
644
+ output_w = int(output_w * scale_ratio)
645
+ output_h = int(output_h * scale_ratio)
646
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
647
+ original_image = np.array(original_image)
648
+ if input_mask is not None:
649
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
650
+ input_mask = np.array(input_mask)
651
+ if original_mask is not None:
652
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
653
+ original_mask = np.array(original_mask)
654
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
655
+ else:
656
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
657
+ pass
658
+ else:
659
+ if resize_default:
660
+ short_side = min(output_w, output_h)
661
+ scale_ratio = 640 / short_side
662
+ output_w = int(output_w * scale_ratio)
663
+ output_h = int(output_h * scale_ratio)
664
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
665
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
666
+ original_image = np.array(original_image)
667
+ if input_mask is not None:
668
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
669
+ input_mask = np.array(input_mask)
670
+ if original_mask is not None:
671
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
672
+ original_mask = np.array(original_mask)
673
+
674
+ if invert_mask_state:
675
+ original_mask = original_mask
676
+ else:
677
+ if input_mask.max() == 0:
678
+ original_mask = original_mask
679
+ else:
680
+ original_mask = input_mask
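+ # a freshly drawn brush layer overrides any previously generated mask, unless the mask was just inverted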
681
+
682
+
683
+ ## skip the VLM editing-type classification if a category or a target prompt was already provided
684
+ if category is not None:
685
+ pass
686
+ elif target_prompt is not None and len(target_prompt) >= 1 and original_mask is not None:
687
+ pass
688
+ else:
689
+ try:
690
+ category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device)
691
+ except Exception as e:
692
+ raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
693
+
694
+
695
+ if original_mask is not None:
696
+ original_mask = np.clip(original_mask, 0, 255).astype(np.uint8)
697
+ else:
698
+ try:
699
+ object_wait_for_edit = vlm_response_object_wait_for_edit(
700
+ vlm_processor,
701
+ vlm_model,
702
+ original_image,
703
+ category,
704
+ prompt,
705
+ device)
706
+
707
+ original_mask = vlm_response_mask(vlm_processor,
708
+ vlm_model,
709
+ category,
710
+ original_image,
711
+ prompt,
712
+ object_wait_for_edit,
713
+ sam,
714
+ sam_predictor,
715
+ sam_automask_generator,
716
+ groundingdino_model,
717
+ device).astype(np.uint8)
718
+ except Exception as e:
719
+ raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
720
+
721
+ if original_mask.ndim == 2:
722
+ original_mask = original_mask[:,:,None]
723
+
724
+
725
+ if target_prompt is not None and len(target_prompt) >= 1:
726
+ prompt_after_apply_instruction = target_prompt
727
+
728
+ else:
729
+ try:
730
+ prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(
731
+ vlm_processor,
732
+ vlm_model,
733
+ original_image,
734
+ prompt,
735
+ device)
736
+ except Exception as e:
737
+ raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
738
+
739
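+ # seed the generator, either from a fresh random seed or the user-specified one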
+ generator = torch.Generator(device).manual_seed(random.randint(0, 2147483647) if randomize_seed else seed)
740
+
741
+
742
+ with torch.autocast(device):
743
+ image, mask_image, mask_np, init_image_np = BrushEdit_Pipeline(pipe,
744
+ prompt_after_apply_instruction,
745
+ original_mask,
746
+ original_image,
747
+ generator,
748
+ num_inference_steps,
749
+ guidance_scale,
750
+ control_strength,
751
+ negative_prompt,
752
+ num_samples,
753
+ blending)
754
+ original_image = np.array(init_image_np)
755
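+ # black out the edited region to produce the "masked image" preview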
+ masked_image = original_image * (1 - (mask_np>0))
756
+ masked_image = masked_image.astype(np.uint8)
757
+ masked_image = Image.fromarray(masked_image)
758
+ # Save the images (optional)
759
+ # import uuid
760
+ # uuid = str(uuid.uuid4())
761
+ # image[0].save(f"outputs/image_edit_{uuid}_0.png")
762
+ # image[1].save(f"outputs/image_edit_{uuid}_1.png")
763
+ # image[2].save(f"outputs/image_edit_{uuid}_2.png")
764
+ # image[3].save(f"outputs/image_edit_{uuid}_3.png")
765
+ # mask_image.save(f"outputs/mask_{uuid}.png")
766
+ # masked_image.save(f"outputs/masked_image_{uuid}.png")
767
+ gr.Info(f"Target Prompt: {prompt_after_apply_instruction}", duration=20)
768
+ return image, [mask_image], [masked_image], prompt, '', False
769
+
770
+
771
+ def generate_target_prompt(input_image,
772
+ original_image,
773
+ prompt):
774
+ # load example image
775
+ if isinstance(original_image, str):
776
+ original_image = input_image
777
+
778
+ prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(
779
+ vlm_processor,
780
+ vlm_model,
781
+ original_image,
782
+ prompt,
783
+ device)
784
+ return prompt_after_apply_instruction
785
+
786
+
787
+ def process_mask(input_image,
788
+ original_image,
789
+ prompt,
790
+ resize_default,
791
+ aspect_ratio_name):
792
+ if original_image is None:
793
+ raise gr.Error('Please upload the input image')
794
+ if prompt is None:
795
+ raise gr.Error("Please input your instructions, e.g., remove the xxx")
796
+
797
+ ## load mask
798
+ alpha_mask = input_image["layers"][0].split()[3]
799
+ input_mask = np.array(alpha_mask)
800
+
801
+ # load example image
802
+ if isinstance(original_image, str):
803
+ original_image = input_image["background"]
804
+
805
+ if input_mask.max() == 0:
806
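+ # no brush strokes were drawn: let the VLM classify the edit and ground a mask via SAM + GroundingDINO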
+ category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device)
807
+
808
+ object_wait_for_edit = vlm_response_object_wait_for_edit(vlm_processor,
809
+ vlm_model,
810
+ original_image,
811
+ category,
812
+ prompt,
813
+ device)
814
+ # original mask: h,w,1 [0, 255]
815
+ original_mask = vlm_response_mask(
816
+ vlm_processor,
817
+ vlm_model,
818
+ category,
819
+ original_image,
820
+ prompt,
821
+ object_wait_for_edit,
822
+ sam,
823
+ sam_predictor,
824
+ sam_automask_generator,
825
+ groundingdino_model,
826
+ device).astype(np.uint8)
827
+ else:
828
+ original_mask = input_mask.astype(np.uint8)
829
+ category = None
830
+
831
+ ## resize mask if needed
832
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
833
+ if output_w == "" or output_h == "":
834
+ output_h, output_w = original_image.shape[:2]
835
+ if resize_default:
836
+ short_side = min(output_w, output_h)
837
+ scale_ratio = 640 / short_side
838
+ output_w = int(output_w * scale_ratio)
839
+ output_h = int(output_h * scale_ratio)
840
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
841
+ original_image = np.array(original_image)
842
+ if input_mask is not None:
843
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
844
+ input_mask = np.array(input_mask)
845
+ if original_mask is not None:
846
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
847
+ original_mask = np.array(original_mask)
848
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
849
+ else:
850
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
851
+ pass
852
+ else:
853
+ if resize_default:
854
+ short_side = min(output_w, output_h)
855
+ scale_ratio = 640 / short_side
856
+ output_w = int(output_w * scale_ratio)
857
+ output_h = int(output_h * scale_ratio)
858
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
859
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
860
+ original_image = np.array(original_image)
861
+ if input_mask is not None:
862
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
863
+ input_mask = np.array(input_mask)
864
+ if original_mask is not None:
865
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
866
+ original_mask = np.array(original_mask)
867
+
868
+
869
+ if original_mask.ndim == 2:
870
+ original_mask = original_mask[:,:,None]
871
+
872
+ mask_image = Image.fromarray(original_mask.squeeze().astype(np.uint8)).convert("RGB")
873
+
874
+ masked_image = original_image * (1 - (original_mask>0))
875
+ masked_image = masked_image.astype(np.uint8)
876
+ masked_image = Image.fromarray(masked_image)
877
+
878
+ return [masked_image], [mask_image], original_mask.astype(np.uint8), category
879
+
880
+
881
+ def process_random_mask(input_image,
882
+ original_image,
883
+ original_mask,
884
+ resize_default,
885
+ aspect_ratio_name,
886
+ ):
887
+
888
+ alpha_mask = input_image["layers"][0].split()[3]
889
+ input_mask = np.asarray(alpha_mask)
890
+
891
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
892
+ if output_w == "" or output_h == "":
893
+ output_h, output_w = original_image.shape[:2]
894
+ if resize_default:
895
+ short_side = min(output_w, output_h)
896
+ scale_ratio = 640 / short_side
897
+ output_w = int(output_w * scale_ratio)
898
+ output_h = int(output_h * scale_ratio)
899
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
900
+ original_image = np.array(original_image)
901
+ if input_mask is not None:
902
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
903
+ input_mask = np.array(input_mask)
904
+ if original_mask is not None:
905
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
906
+ original_mask = np.array(original_mask)
907
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
908
+ else:
909
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
910
+ pass
911
+ else:
912
+ if resize_default:
913
+ short_side = min(output_w, output_h)
914
+ scale_ratio = 640 / short_side
915
+ output_w = int(output_w * scale_ratio)
916
+ output_h = int(output_h * scale_ratio)
917
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
918
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
919
+ original_image = np.array(original_image)
920
+ if input_mask is not None:
921
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
922
+ input_mask = np.array(input_mask)
923
+ if original_mask is not None:
924
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
925
+ original_mask = np.array(original_mask)
926
+
927
+
928
+ if input_mask.max() == 0:
929
+ original_mask = original_mask
930
+ else:
931
+ original_mask = input_mask
932
+
933
+ if original_mask is None:
934
+ raise gr.Error('Please generate mask first')
935
+
936
+ if original_mask.ndim == 2:
937
+ original_mask = original_mask[:,:,None]
938
+
939
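+ # randomly replace the mask with either its bounding box or its bounding ellipse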
+ dilation_type = np.random.choice(['bounding_box', 'bounding_ellipse'])
940
+ random_mask = random_mask_func(original_mask, dilation_type).squeeze()
941
+
942
+ mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
943
+
944
+ masked_image = original_image * (1 - (random_mask[:,:,None]>0))
945
+ masked_image = masked_image.astype(original_image.dtype)
946
+ masked_image = Image.fromarray(masked_image)
947
+
948
+
949
+ return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
950
+
951
+
952
+ def process_dilation_mask(input_image,
953
+ original_image,
954
+ original_mask,
955
+ resize_default,
956
+ aspect_ratio_name,
957
+ dilation_size=20):
958
+
959
+ alpha_mask = input_image["layers"][0].split()[3]
960
+ input_mask = np.asarray(alpha_mask)
961
+
962
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
963
+ if output_w == "" or output_h == "":
964
+ output_h, output_w = original_image.shape[:2]
965
+ if resize_default:
966
+ short_side = min(output_w, output_h)
967
+ scale_ratio = 640 / short_side
968
+ output_w = int(output_w * scale_ratio)
969
+ output_h = int(output_h * scale_ratio)
970
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
971
+ original_image = np.array(original_image)
972
+ if input_mask is not None:
973
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
974
+ input_mask = np.array(input_mask)
975
+ if original_mask is not None:
976
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
977
+ original_mask = np.array(original_mask)
978
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
979
+ else:
980
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
981
+ pass
982
+ else:
983
+ if resize_default:
984
+ short_side = min(output_w, output_h)
985
+ scale_ratio = 640 / short_side
986
+ output_w = int(output_w * scale_ratio)
987
+ output_h = int(output_h * scale_ratio)
988
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
989
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
990
+ original_image = np.array(original_image)
991
+ if input_mask is not None:
992
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
993
+ input_mask = np.array(input_mask)
994
+ if original_mask is not None:
995
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
996
+ original_mask = np.array(original_mask)
997
+
998
+ if input_mask.max() == 0:
999
+ original_mask = original_mask
1000
+ else:
1001
+ original_mask = input_mask
1002
+
1003
+ if original_mask is None:
1004
+ raise gr.Error('Please generate mask first')
1005
+
1006
+ if original_mask.ndim == 2:
1007
+ original_mask = original_mask[:,:,None]
1008
+
1009
+ dilation_type = np.random.choice(['square_dilation'])
1010
+ random_mask = random_mask_func(original_mask, dilation_type, dilation_size).squeeze()
1011
+
1012
+ mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
1013
+
1014
+ masked_image = original_image * (1 - (random_mask[:,:,None]>0))
1015
+ masked_image = masked_image.astype(original_image.dtype)
1016
+ masked_image = Image.fromarray(masked_image)
1017
+
1018
+ return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
1019
+
1020
+
1021
+ def process_erosion_mask(input_image,
1022
+ original_image,
1023
+ original_mask,
1024
+ resize_default,
1025
+ aspect_ratio_name,
1026
+ dilation_size=20):
1027
+ alpha_mask = input_image["layers"][0].split()[3]
1028
+ input_mask = np.asarray(alpha_mask)
1029
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
1030
+ if output_w == "" or output_h == "":
1031
+ output_h, output_w = original_image.shape[:2]
1032
+ if resize_default:
1033
+ short_side = min(output_w, output_h)
1034
+ scale_ratio = 640 / short_side
1035
+ output_w = int(output_w * scale_ratio)
1036
+ output_h = int(output_h * scale_ratio)
1037
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1038
+ original_image = np.array(original_image)
1039
+ if input_mask is not None:
1040
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1041
+ input_mask = np.array(input_mask)
1042
+ if original_mask is not None:
1043
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1044
+ original_mask = np.array(original_mask)
1045
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1046
+ else:
1047
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1048
+ pass
1049
+ else:
1050
+ if resize_default:
1051
+ short_side = min(output_w, output_h)
1052
+ scale_ratio = 640 / short_side
1053
+ output_w = int(output_w * scale_ratio)
1054
+ output_h = int(output_h * scale_ratio)
1055
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1056
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1057
+ original_image = np.array(original_image)
1058
+ if input_mask is not None:
1059
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1060
+ input_mask = np.array(input_mask)
1061
+ if original_mask is not None:
1062
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1063
+ original_mask = np.array(original_mask)
1064
+
1065
+ if input_mask.max() == 0:
1066
+ original_mask = original_mask
1067
+ else:
1068
+ original_mask = input_mask
1069
+
1070
+ if original_mask is None:
1071
+ raise gr.Error('Please generate mask first')
1072
+
1073
+ if original_mask.ndim == 2:
1074
+ original_mask = original_mask[:,:,None]
1075
+
1076
+ dilation_type = np.random.choice(['square_erosion'])
1077
+ random_mask = random_mask_func(original_mask, dilation_type, dilation_size).squeeze()
1078
+
1079
+ mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
1080
+
1081
+ masked_image = original_image * (1 - (random_mask[:,:,None]>0))
1082
+ masked_image = masked_image.astype(original_image.dtype)
1083
+ masked_image = Image.fromarray(masked_image)
1084
+
1085
+
1086
+ return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
1087
+
1088
+
1089
+ def move_mask_left(input_image,
1090
+ original_image,
1091
+ original_mask,
1092
+ moving_pixels,
1093
+ resize_default,
1094
+ aspect_ratio_name):
1095
+
1096
+ alpha_mask = input_image["layers"][0].split()[3]
1097
+ input_mask = np.asarray(alpha_mask)
1098
+
1099
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
1100
+ if output_w == "" or output_h == "":
1101
+ output_h, output_w = original_image.shape[:2]
1102
+ if resize_default:
1103
+ short_side = min(output_w, output_h)
1104
+ scale_ratio = 640 / short_side
1105
+ output_w = int(output_w * scale_ratio)
1106
+ output_h = int(output_h * scale_ratio)
1107
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1108
+ original_image = np.array(original_image)
1109
+ if input_mask is not None:
1110
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1111
+ input_mask = np.array(input_mask)
1112
+ if original_mask is not None:
1113
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1114
+ original_mask = np.array(original_mask)
1115
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1116
+ else:
1117
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1118
+ pass
1119
+ else:
1120
+ if resize_default:
1121
+ short_side = min(output_w, output_h)
1122
+ scale_ratio = 640 / short_side
1123
+ output_w = int(output_w * scale_ratio)
1124
+ output_h = int(output_h * scale_ratio)
1125
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1126
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1127
+ original_image = np.array(original_image)
1128
+ if input_mask is not None:
1129
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1130
+ input_mask = np.array(input_mask)
1131
+ if original_mask is not None:
1132
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1133
+ original_mask = np.array(original_mask)
1134
+
1135
+ if input_mask.max() == 0:
1136
+ original_mask = original_mask
1137
+ else:
1138
+ original_mask = input_mask
1139
+
1140
+ if original_mask is None:
1141
+ raise gr.Error('Please generate mask first')
1142
+
1143
+ if original_mask.ndim == 2:
1144
+ original_mask = original_mask[:,:,None]
1145
+
1146
+ moved_mask = move_mask_func(original_mask, 'left', int(moving_pixels)).squeeze()
1147
+ mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
1148
+
1149
+ masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
1150
+ masked_image = masked_image.astype(original_image.dtype)
1151
+ masked_image = Image.fromarray(masked_image)
1152
+
1153
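+ # if the moved mask is still 0/1, rescale it to 0-255 before storing it as the new mask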
+ if moved_mask.max() <= 1:
1154
+ moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
1155
+ original_mask = moved_mask
1156
+ return [masked_image], [mask_image], original_mask.astype(np.uint8)
1157
+
1158
+
1159
+ def move_mask_right(input_image,
1160
+ original_image,
1161
+ original_mask,
1162
+ moving_pixels,
1163
+ resize_default,
1164
+ aspect_ratio_name):
1165
+ alpha_mask = input_image["layers"][0].split()[3]
1166
+ input_mask = np.asarray(alpha_mask)
1167
+
1168
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
1169
+ if output_w == "" or output_h == "":
1170
+ output_h, output_w = original_image.shape[:2]
1171
+ if resize_default:
1172
+ short_side = min(output_w, output_h)
1173
+ scale_ratio = 640 / short_side
1174
+ output_w = int(output_w * scale_ratio)
1175
+ output_h = int(output_h * scale_ratio)
1176
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1177
+ original_image = np.array(original_image)
1178
+ if input_mask is not None:
1179
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1180
+ input_mask = np.array(input_mask)
1181
+ if original_mask is not None:
1182
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1183
+ original_mask = np.array(original_mask)
1184
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1185
+ else:
1186
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1187
+ pass
1188
+ else:
1189
+ if resize_default:
1190
+ short_side = min(output_w, output_h)
1191
+ scale_ratio = 640 / short_side
1192
+ output_w = int(output_w * scale_ratio)
1193
+ output_h = int(output_h * scale_ratio)
1194
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1195
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1196
+ original_image = np.array(original_image)
1197
+ if input_mask is not None:
1198
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1199
+ input_mask = np.array(input_mask)
1200
+ if original_mask is not None:
1201
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1202
+ original_mask = np.array(original_mask)
1203
+
1204
+ if input_mask.max() == 0:
1205
+ original_mask = original_mask
1206
+ else:
1207
+ original_mask = input_mask
1208
+
1209
+ if original_mask is None:
1210
+ raise gr.Error('Please generate mask first')
1211
+
1212
+ if original_mask.ndim == 2:
1213
+ original_mask = original_mask[:,:,None]
1214
+
1215
+ moved_mask = move_mask_func(original_mask, 'right', int(moving_pixels)).squeeze()
1216
+
1217
+ mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
1218
+
1219
+ masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
1220
+ masked_image = masked_image.astype(original_image.dtype)
1221
+ masked_image = Image.fromarray(masked_image)
1222
+
1223
+
1224
+ if moved_mask.max() <= 1:
1225
+ moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
1226
+ original_mask = moved_mask
1227
+
1228
+ return [masked_image], [mask_image], original_mask.astype(np.uint8)
1229
+
1230
+
1231
+ def move_mask_up(input_image,
1232
+ original_image,
1233
+ original_mask,
1234
+ moving_pixels,
1235
+ resize_default,
1236
+ aspect_ratio_name):
1237
+ alpha_mask = input_image["layers"][0].split()[3]
1238
+ input_mask = np.asarray(alpha_mask)
1239
+
1240
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
1241
+ if output_w == "" or output_h == "":
1242
+ output_h, output_w = original_image.shape[:2]
1243
+ if resize_default:
1244
+ short_side = min(output_w, output_h)
1245
+ scale_ratio = 640 / short_side
1246
+ output_w = int(output_w * scale_ratio)
1247
+ output_h = int(output_h * scale_ratio)
1248
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1249
+ original_image = np.array(original_image)
1250
+ if input_mask is not None:
1251
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1252
+ input_mask = np.array(input_mask)
1253
+ if original_mask is not None:
1254
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1255
+ original_mask = np.array(original_mask)
1256
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1257
+ else:
1258
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1259
+ pass
1260
+ else:
1261
+ if resize_default:
1262
+ short_side = min(output_w, output_h)
1263
+ scale_ratio = 640 / short_side
1264
+ output_w = int(output_w * scale_ratio)
1265
+ output_h = int(output_h * scale_ratio)
1266
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1267
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1268
+ original_image = np.array(original_image)
1269
+ if input_mask is not None:
1270
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1271
+ input_mask = np.array(input_mask)
1272
+ if original_mask is not None:
1273
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1274
+ original_mask = np.array(original_mask)
1275
+
1276
+ if input_mask.max() == 0:
1277
+ original_mask = original_mask
1278
+ else:
1279
+ original_mask = input_mask
1280
+
1281
+ if original_mask is None:
1282
+ raise gr.Error('Please generate mask first')
1283
+
1284
+ if original_mask.ndim == 2:
1285
+ original_mask = original_mask[:,:,None]
1286
+
1287
+ moved_mask = move_mask_func(original_mask, 'up', int(moving_pixels)).squeeze()
1288
+ mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
1289
+
1290
+ masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
1291
+ masked_image = masked_image.astype(original_image.dtype)
1292
+ masked_image = Image.fromarray(masked_image)
1293
+
1294
+ if moved_mask.max() <= 1:
1295
+ moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
1296
+ original_mask = moved_mask
1297
+
1298
+ return [masked_image], [mask_image], original_mask.astype(np.uint8)
1299
+
1300
+
1301
+ def move_mask_down(input_image,
1302
+ original_image,
1303
+ original_mask,
1304
+ moving_pixels,
1305
+ resize_default,
1306
+ aspect_ratio_name):
1307
+ alpha_mask = input_image["layers"][0].split()[3]
1308
+ input_mask = np.asarray(alpha_mask)
1309
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
1310
+ if output_w == "" or output_h == "":
1311
+ output_h, output_w = original_image.shape[:2]
1312
+ if resize_default:
1313
+ short_side = min(output_w, output_h)
1314
+ scale_ratio = 640 / short_side
1315
+ output_w = int(output_w * scale_ratio)
1316
+ output_h = int(output_h * scale_ratio)
1317
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1318
+ original_image = np.array(original_image)
1319
+ if input_mask is not None:
1320
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1321
+ input_mask = np.array(input_mask)
1322
+ if original_mask is not None:
1323
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1324
+ original_mask = np.array(original_mask)
1325
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1326
+ else:
1327
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1328
+ pass
1329
+ else:
1330
+ if resize_default:
1331
+ short_side = min(output_w, output_h)
1332
+ scale_ratio = 640 / short_side
1333
+ output_w = int(output_w * scale_ratio)
1334
+ output_h = int(output_h * scale_ratio)
1335
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1336
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1337
+ original_image = np.array(original_image)
1338
+ if input_mask is not None:
1339
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1340
+ input_mask = np.array(input_mask)
1341
+ if original_mask is not None:
1342
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1343
+ original_mask = np.array(original_mask)
1344
+
1345
+ if input_mask.max() == 0:
1346
+ original_mask = original_mask
1347
+ else:
1348
+ original_mask = input_mask
1349
+
1350
+ if original_mask is None:
1351
+ raise gr.Error('Please generate mask first')
1352
+
1353
+ if original_mask.ndim == 2:
1354
+ original_mask = original_mask[:,:,None]
1355
+
1356
+ moved_mask = move_mask_func(original_mask, 'down', int(moving_pixels)).squeeze()
1357
+ mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
1358
+
1359
+ masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
1360
+ masked_image = masked_image.astype(original_image.dtype)
1361
+ masked_image = Image.fromarray(masked_image)
1362
+
1363
+ if moved_mask.max() <= 1:
1364
+ moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
1365
+ original_mask = moved_mask
1366
+
1367
+ return [masked_image], [mask_image], original_mask.astype(np.uint8)
1368
+
1369
+
1370
+ def invert_mask(input_image,
1371
+ original_image,
1372
+ original_mask,
1373
+ ):
1374
+ alpha_mask = input_image["layers"][0].split()[3]
1375
+ input_mask = np.asarray(alpha_mask)
1376
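+ # invert whichever mask is available: the brush layer if one was drawn, otherwise the stored mask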
+ if input_mask.max() == 0:
1377
+ original_mask = 1 - (original_mask>0).astype(np.uint8)
1378
+ else:
1379
+ original_mask = 1 - (input_mask>0).astype(np.uint8)
1380
+
1381
+ if original_mask is None:
1382
+ raise gr.Error('Please generate mask first')
1383
+
1384
+ original_mask = original_mask.squeeze()
1385
+ mask_image = Image.fromarray(original_mask*255).convert("RGB")
1386
+
1387
+ if original_mask.ndim == 2:
1388
+ original_mask = original_mask[:,:,None]
1389
+
1390
+ if original_mask.max() <= 1:
1391
+ original_mask = (original_mask * 255).astype(np.uint8)
1392
+
1393
+ masked_image = original_image * (1 - (original_mask>0))
1394
+ masked_image = masked_image.astype(original_image.dtype)
1395
+ masked_image = Image.fromarray(masked_image)
1396
+
1397
+ return [masked_image], [mask_image], original_mask, True
1398
+
1399
+
1400
+ def init_img(base,
1401
+ init_type,
1402
+ prompt,
1403
+ aspect_ratio,
1404
+ example_change_times
1405
+ ):
1406
+ image_pil = base["background"].convert("RGB")
1407
+ original_image = np.array(image_pil)
1408
+ if max(original_image.shape[0], original_image.shape[1]) * 1.0 / min(original_image.shape[0], original_image.shape[1])>2.0:
1409
+ raise gr.Error('image aspect ratio cannot be larger than 2.0')
1410
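+ # for bundled examples, reuse the cached mask / masked-image / result galleries instead of recomputing them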
+ if init_type in MASK_IMAGE_PATH.keys() and example_change_times < 2:
1411
+ mask_gallery = [Image.open(MASK_IMAGE_PATH[init_type]).convert("L")]
1412
+ masked_gallery = [Image.open(MASKED_IMAGE_PATH[init_type]).convert("RGB")]
1413
+ result_gallery = [Image.open(OUTPUT_IMAGE_PATH[init_type]).convert("RGB")]
1414
+ width, height = image_pil.size
1415
+ image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
1416
+ height_new, width_new = image_processor.get_default_height_width(image_pil, height, width)
1417
+ image_pil = image_pil.resize((width_new, height_new))
1418
+ mask_gallery[0] = mask_gallery[0].resize((width_new, height_new))
1419
+ masked_gallery[0] = masked_gallery[0].resize((width_new, height_new))
1420
+ result_gallery[0] = result_gallery[0].resize((width_new, height_new))
1421
+ original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1
1422
+ return base, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, "", "", "Custom resolution", False, False, example_change_times
1423
+ else:
1424
+ if aspect_ratio not in ASPECT_RATIO_LABELS:
1425
+ aspect_ratio = "Custom resolution"
1426
+ return base, original_image, None, "", None, None, None, "", "", aspect_ratio, True, False, 0
1427
+
1428
+
1429
+ def reset_func(input_image,
1430
+ original_image,
1431
+ original_mask,
1432
+ prompt,
1433
+ target_prompt,
1434
+ ):
1435
+ input_image = None
1436
+ original_image = None
1437
+ original_mask = None
1438
+ prompt = ''
1439
+ mask_gallery = []
1440
+ masked_gallery = []
1441
+ result_gallery = []
1442
+ target_prompt = ''
1443
+ if torch.cuda.is_available():
1444
+ torch.cuda.empty_cache()
1445
+ return input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, True, False
1446
+
1447
+
1448
+ def update_example(example_type,
1449
+ prompt,
1450
+ example_change_times):
1451
+ input_image = INPUT_IMAGE_PATH[example_type]
1452
+ image_pil = Image.open(input_image).convert("RGB")
1453
+ mask_gallery = [Image.open(MASK_IMAGE_PATH[example_type]).convert("L")]
1454
+ masked_gallery = [Image.open(MASKED_IMAGE_PATH[example_type]).convert("RGB")]
1455
+ result_gallery = [Image.open(OUTPUT_IMAGE_PATH[example_type]).convert("RGB")]
1456
+ width, height = image_pil.size
1457
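+ # resize the example image and its cached galleries to the VAE-friendly resolution used by the pipeline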
+ image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
1458
+ height_new, width_new = image_processor.get_default_height_width(image_pil, height, width)
1459
+ image_pil = image_pil.resize((width_new, height_new))
1460
+ mask_gallery[0] = mask_gallery[0].resize((width_new, height_new))
1461
+ masked_gallery[0] = masked_gallery[0].resize((width_new, height_new))
1462
+ result_gallery[0] = result_gallery[0].resize((width_new, height_new))
1463
+
1464
+ original_image = np.array(image_pil)
1465
+ original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1
1466
+ aspect_ratio = "Custom resolution"
1467
+ example_change_times += 1
1468
+ return input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, "", False, example_change_times
1469
+
1470
+
1471
+ block = gr.Blocks(
1472
+ theme=gr.themes.Soft(
1473
+ radius_size=gr.themes.sizes.radius_none,
1474
+ text_size=gr.themes.sizes.text_md
1475
+ )
1476
+ )
1477
+ with block as demo:
1478
+ with gr.Row():
1479
+ with gr.Column():
1480
+ gr.HTML(head)
1481
+
1482
+ gr.Markdown(descriptions)
1483
+
1484
+ with gr.Accordion(label="🧭 Instructions:", open=True, elem_id="accordion"):
1485
+ with gr.Row(equal_height=True):
1486
+ gr.Markdown(instructions)
1487
+
1488
+ original_image = gr.State(value=None)
1489
+ original_mask = gr.State(value=None)
1490
+ category = gr.State(value=None)
1491
+ status = gr.State(value=None)
1492
+ invert_mask_state = gr.State(value=False)
1493
+ example_change_times = gr.State(value=0)
1494
+
1495
+
1496
+ with gr.Row():
1497
+ with gr.Column():
1498
+ with gr.Row():
1499
+ input_image = gr.ImageEditor(
1500
+ label="Input Image",
1501
+ type="pil",
1502
+ brush=gr.Brush(colors=["#FFFFFF"], default_size = 30, color_mode="fixed"),
1503
+ layers = False,
1504
+ interactive=True,
1505
+ height=1024,
1506
+ sources=["upload"],
1507
+ placeholder="Please click here or the icon below to upload the image.",
1508
+ )
1509
+
1510
+ prompt = gr.Textbox(label="⌨️ Instruction", placeholder="Please input your instruction.", value="",lines=1)
1511
+ run_button = gr.Button("💫 Run")
1512
+
1513
+ vlm_model_dropdown = gr.Dropdown(label="VLM model", choices=VLM_MODEL_NAMES, value=DEFAULT_VLM_MODEL_NAME, interactive=True)
1514
+ with gr.Group():
1515
+ with gr.Row():
1516
+ GPT4o_KEY = gr.Textbox(label="GPT4o API Key", placeholder="Please input your GPT-4o API key when using the GPT-4o VLM (highly recommended).", value="", lines=1)
1517
+
1518
+ GPT4o_KEY_submit = gr.Button("Submit and Verify")
1519
+
1520
+
1521
+ aspect_ratio = gr.Dropdown(label="Output aspect ratio", choices=ASPECT_RATIO_LABELS, value=DEFAULT_ASPECT_RATIO)
1522
+ resize_default = gr.Checkbox(label="Short edge resize to 640px", value=True)
1523
+
1524
+ with gr.Row():
1525
+ mask_button = gr.Button("Generate Mask")
1526
+ random_mask_button = gr.Button("Square/Circle Mask")
1527
+
1528
+
1529
+ with gr.Row():
1530
+ generate_target_prompt_button = gr.Button("Generate Target Prompt")
1531
+
1532
+ target_prompt = gr.Text(
1533
+ label="Input Target Prompt",
1534
+ max_lines=5,
1535
+ placeholder="VLM-generated target prompt; you can generate it first and then modify it (optional)",
1536
+ value='',
1537
+ lines=2
1538
+ )
1539
+
1540
+ with gr.Accordion("Advanced Options", open=False, elem_id="accordion1"):
1541
+ base_model_dropdown = gr.Dropdown(label="Base model", choices=BASE_MODELS, value=DEFAULT_BASE_MODEL, interactive=True)
1542
+ negative_prompt = gr.Text(
1543
+ label="Negative Prompt",
1544
+ max_lines=5,
1545
+ placeholder="Please input your negative prompt",
1546
+ value='ugly, low quality',lines=1
1547
+ )
1548
+
1549
+ control_strength = gr.Slider(
1550
+ label="Control Strength: ", show_label=True, minimum=0, maximum=1.1, value=1, step=0.01
1551
+ )
1552
+ with gr.Group():
1553
+ seed = gr.Slider(
1554
+ label="Seed: ", minimum=0, maximum=2147483647, step=1, value=648464818
1555
+ )
1556
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
1557
+
1558
+ blending = gr.Checkbox(label="Blending mode", value=True)
1559
+
1560
+
1561
+ num_samples = gr.Slider(
1562
+ label="Num samples", minimum=0, maximum=4, step=1, value=4
1563
+ )
1564
+
1565
+ with gr.Group():
1566
+ with gr.Row():
1567
+ guidance_scale = gr.Slider(
1568
+ label="Guidance scale",
1569
+ minimum=1,
1570
+ maximum=12,
1571
+ step=0.1,
1572
+ value=7.5,
1573
+ )
1574
+ num_inference_steps = gr.Slider(
1575
+ label="Number of inference steps",
1576
+ minimum=1,
1577
+ maximum=50,
1578
+ step=1,
1579
+ value=50,
1580
+ )
1581
+
1582
+
1583
+ with gr.Column():
1584
+ with gr.Row():
1585
+ with gr.Tab(elem_classes="feedback", label="Masked Image"):
1586
+ masked_gallery = gr.Gallery(label='Masked Image', show_label=True, elem_id="gallery", preview=True, height=360)
1587
+ with gr.Tab(elem_classes="feedback", label="Mask"):
1588
+ mask_gallery = gr.Gallery(label='Mask', show_label=True, elem_id="gallery", preview=True, height=360)
1589
+
1590
+ invert_mask_button = gr.Button("Invert Mask")
1591
+ dilation_size = gr.Slider(
1592
+ label="Dilation size: ", minimum=0, maximum=50, step=1, value=20
1593
+ )
1594
+ with gr.Row():
1595
+ dilation_mask_button = gr.Button("Dilate Generated Mask")
1596
+ erosion_mask_button = gr.Button("Erode Generated Mask")
1597
+
1598
+ moving_pixels = gr.Slider(
1599
+ label="Moving pixels:", show_label=True, minimum=0, maximum=50, value=4, step=1
1600
+ )
1601
+ with gr.Row():
1602
+ move_left_button = gr.Button("Move Left")
1603
+ move_right_button = gr.Button("Move Right")
1604
+ with gr.Row():
1605
+ move_up_button = gr.Button("Move Up")
1606
+ move_down_button = gr.Button("Move Down")
1607
+
1608
+ with gr.Tab(elem_classes="feedback", label="Output"):
1609
+ result_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", preview=True, height=400)
1610
+
1611
+ # target_prompt_output = gr.Text(label="Output Target Prompt", value="", lines=1, interactive=False)
1612
+
1613
+ reset_button = gr.Button("Reset")
1614
+
1615
+ init_type = gr.Textbox(label="Init Name", value="", visible=False)
1616
+ example_type = gr.Textbox(label="Example Name", value="", visible=False)
1617
+
1618
+
1619
+
1620
+ with gr.Row():
1621
+ example = gr.Examples(
1622
+ label="Quick Example",
1623
+ examples=EXAMPLES,
1624
+ inputs=[input_image, prompt, seed, init_type, example_type, blending, resize_default, vlm_model_dropdown],
1625
+ examples_per_page=10,
1626
+ cache_examples=False,
1627
+ )
1628
+
1629
+
1630
+ with gr.Accordion(label="🎬 Feature Details:", open=True, elem_id="accordion"):
1631
+ with gr.Row(equal_height=True):
1632
+ gr.Markdown(tips)
1633
+
1634
+ with gr.Row():
1635
+ gr.Markdown(citation)
1636
+
1637
+ ## gr.Examples cannot update a gr.Gallery directly, so the two callbacks below refresh the galleries instead.
1638
+ ## They also resolve the conflict between uploading a new image and switching examples.
1639
+ input_image.upload(
1640
+ init_img,
1641
+ [input_image, init_type, prompt, aspect_ratio, example_change_times],
1642
+ [input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, init_type, aspect_ratio, resize_default, invert_mask_state, example_change_times]
1643
+ )
1644
+ example_type.change(fn=update_example, inputs=[example_type, prompt, example_change_times], outputs=[input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, target_prompt, invert_mask_state, example_change_times])
1645
+
1646
+ ## vlm and base model dropdown
1647
+ vlm_model_dropdown.change(fn=update_vlm_model, inputs=[vlm_model_dropdown], outputs=[status])
1648
+ base_model_dropdown.change(fn=update_base_model, inputs=[base_model_dropdown], outputs=[status])
1649
+
1650
+
1651
+ GPT4o_KEY_submit.click(fn=submit_GPT4o_KEY, inputs=[GPT4o_KEY], outputs=[GPT4o_KEY, vlm_model_dropdown])
1652
+ invert_mask_button.click(fn=invert_mask, inputs=[input_image, original_image, original_mask], outputs=[masked_gallery, mask_gallery, original_mask, invert_mask_state])
1653
+
1654
+
1655
+ ips=[input_image,
1656
+ original_image,
1657
+ original_mask,
1658
+ prompt,
1659
+ negative_prompt,
1660
+ control_strength,
1661
+ seed,
1662
+ randomize_seed,
1663
+ guidance_scale,
1664
+ num_inference_steps,
1665
+ num_samples,
1666
+ blending,
1667
+ category,
1668
+ target_prompt,
1669
+ resize_default,
1670
+ aspect_ratio,
1671
+ invert_mask_state]
1672
+
1673
+ ## run brushedit
1674
+ run_button.click(fn=process, inputs=ips, outputs=[result_gallery, mask_gallery, masked_gallery, prompt, target_prompt, invert_mask_state])
1675
+
1676
+ ## mask func
1677
+ mask_button.click(fn=process_mask, inputs=[input_image, original_image, prompt, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask, category])
1678
+ random_mask_button.click(fn=process_random_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
1679
+ dilation_mask_button.click(fn=process_dilation_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[ masked_gallery, mask_gallery, original_mask])
1680
+ erosion_mask_button.click(fn=process_erosion_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[ masked_gallery, mask_gallery, original_mask])
1681
+
1682
+ ## move mask func
1683
+ move_left_button.click(fn=move_mask_left, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
1684
+ move_right_button.click(fn=move_mask_right, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
1685
+ move_up_button.click(fn=move_mask_up, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
1686
+ move_down_button.click(fn=move_mask_down, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
1687
+
1688
+ ## prompt func
1689
+ generate_target_prompt_button.click(fn=generate_target_prompt, inputs=[input_image, original_image, prompt], outputs=[target_prompt])
1690
+
1691
+ ## reset func
1692
+ reset_button.click(fn=reset_func, inputs=[input_image, original_image, original_mask, prompt, target_prompt], outputs=[input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, resize_default, invert_mask_state])
1693
+
1694
+ ## if you hit a localhost access error, keep the share=True launch below; otherwise the plain demo.launch() also works
1695
+ demo.launch(server_name="0.0.0.0", server_port=12345, share=True)
1696
+ # demo.launch()
brushedit_app_315_1.py ADDED
@@ -0,0 +1,1624 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import os, random, sys
4
+ import numpy as np
5
+ import requests
6
+ import torch
7
+
8
+
9
+ import gradio as gr
10
+
11
+ from PIL import Image
12
+
13
+
14
+ from huggingface_hub import hf_hub_download, snapshot_download
15
+ from scipy.ndimage import binary_dilation, binary_erosion
16
+ from transformers import (LlavaNextProcessor, LlavaNextForConditionalGeneration,
17
+ Qwen2VLForConditionalGeneration, Qwen2VLProcessor)
18
+
19
+ from segment_anything import SamPredictor, build_sam, SamAutomaticMaskGenerator
20
+ from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler
21
+ from diffusers.image_processor import VaeImageProcessor
22
+
23
+
24
+ from app.src.vlm_pipeline import (
25
+ vlm_response_editing_type,
26
+ vlm_response_object_wait_for_edit,
27
+ vlm_response_mask,
28
+ vlm_response_prompt_after_apply_instruction
29
+ )
30
+ from app.src.brushedit_all_in_one_pipeline import BrushEdit_Pipeline
31
+ from app.utils.utils import load_grounding_dino_model
32
+
33
+ from app.src.vlm_template import vlms_template
34
+ from app.src.base_model_template import base_models_template
35
+ from app.src.aspect_ratio_template import aspect_ratios
36
+
37
+ from openai import OpenAI
38
+ # base_openai_url = ""
39
+
40
+ #### Description ####
41
+ logo = r"""
42
+ <center><img src='./assets/logo_brushedit.png' alt='BrushEdit logo' style="width:80px; margin-bottom:10px"></center>
43
+ """
44
+ head = r"""
45
+ <div style="text-align: center;">
46
+ <h1> BrushEdit: All-In-One Image Inpainting and Editing</h1>
47
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
48
+ <a href='https://liyaowei-stu.github.io/project/BrushEdit/'><img src='https://img.shields.io/badge/Project_Page-BrushEdit-green' alt='Project Page'></a>
49
+ <a href='https://arxiv.org/abs/2412.10316'><img src='https://img.shields.io/badge/Paper-Arxiv-blue'></a>
50
+ <a href='https://github.com/TencentARC/BrushEdit'><img src='https://img.shields.io/badge/Code-Github-orange'></a>
51
+
52
+ </div>
53
+ </br>
54
+ </div>
55
+ """
56
+ descriptions = r"""
57
+ Official Gradio Demo
58
+ """
59
+
60
+ instructions = r"""
61
+ To be added.
62
+ """
63
+
64
+ tips = r"""
65
+ To be added.
66
+
67
+ """
68
+
69
+
70
+
71
+ citation = r"""
72
+ To be added.
73
+ """
74
+
75
+ # - - - - - examples - - - - - #
76
+ EXAMPLES = [
77
+
78
+ [
79
+ Image.open("./assets/frog/frog.jpeg").convert("RGBA"),
80
+ "add a magic hat on frog head.",
81
+ 642087011,
82
+ "frog",
83
+ "frog",
84
+ True,
85
+ False,
86
+ "GPT4-o (Highly Recommended)"
87
+ ],
88
+ [
89
+ Image.open("./assets/chinese_girl/chinese_girl.png").convert("RGBA"),
90
+ "replace the background to ancient China.",
91
+ 648464818,
92
+ "chinese_girl",
93
+ "chinese_girl",
94
+ True,
95
+ False,
96
+ "GPT4-o (Highly Recommended)"
97
+ ],
98
+ [
99
+ Image.open("./assets/angel_christmas/angel_christmas.png").convert("RGBA"),
100
+ "remove the deer.",
101
+ 648464818,
102
+ "angel_christmas",
103
+ "angel_christmas",
104
+ False,
105
+ False,
106
+ "GPT4-o (Highly Recommended)"
107
+ ],
108
+ [
109
+ Image.open("./assets/sunflower_girl/sunflower_girl.png").convert("RGBA"),
110
+ "add a wreath on head.",
111
+ 648464818,
112
+ "sunflower_girl",
113
+ "sunflower_girl",
114
+ True,
115
+ False,
116
+ "GPT4-o (Highly Recommended)"
117
+ ],
118
+ [
119
+ Image.open("./assets/girl_on_sun/girl_on_sun.png").convert("RGBA"),
120
+ "add a butterfly fairy.",
121
+ 648464818,
122
+ "girl_on_sun",
123
+ "girl_on_sun",
124
+ True,
125
+ False,
126
+ "GPT4-o (Highly Recommended)"
127
+ ],
128
+ [
129
+ Image.open("./assets/spider_man_rm/spider_man.png").convert("RGBA"),
130
+ "remove the christmas hat.",
131
+ 642087011,
132
+ "spider_man_rm",
133
+ "spider_man_rm",
134
+ False,
135
+ False,
136
+ "GPT4-o (Highly Recommended)"
137
+ ],
138
+ [
139
+ Image.open("./assets/anime_flower/anime_flower.png").convert("RGBA"),
140
+ "remove the flower.",
141
+ 642087011,
142
+ "anime_flower",
143
+ "anime_flower",
144
+ False,
145
+ False,
146
+ "GPT4-o (Highly Recommended)"
147
+ ],
148
+ [
149
+ Image.open("./assets/chenduling/chengduling.jpg").convert("RGBA"),
150
+ "replace the clothes to a delicated floral skirt.",
151
+ 648464818,
152
+ "chenduling",
153
+ "chenduling",
154
+ True,
155
+ False,
156
+ "GPT4-o (Highly Recommended)"
157
+ ],
158
+ [
159
+ Image.open("./assets/hedgehog_rp_bg/hedgehog.png").convert("RGBA"),
160
+ "make the hedgehog in Italy.",
161
+ 648464818,
162
+ "hedgehog_rp_bg",
163
+ "hedgehog_rp_bg",
164
+ True,
165
+ False,
166
+ "GPT4-o (Highly Recommended)"
167
+ ],
168
+
169
+ ]
170
+
171
+ INPUT_IMAGE_PATH = {
172
+ "frog": "./assets/frog/frog.jpeg",
173
+ "chinese_girl": "./assets/chinese_girl/chinese_girl.png",
174
+ "angel_christmas": "./assets/angel_christmas/angel_christmas.png",
175
+ "sunflower_girl": "./assets/sunflower_girl/sunflower_girl.png",
176
+ "girl_on_sun": "./assets/girl_on_sun/girl_on_sun.png",
177
+ "spider_man_rm": "./assets/spider_man_rm/spider_man.png",
178
+ "anime_flower": "./assets/anime_flower/anime_flower.png",
179
+ "chenduling": "./assets/chenduling/chengduling.jpg",
180
+ "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/hedgehog.png",
181
+ }
182
+ MASK_IMAGE_PATH = {
183
+ "frog": "./assets/frog/mask_f7b350de-6f2c-49e3-b535-995c486d78e7.png",
184
+ "chinese_girl": "./assets/chinese_girl/mask_54759648-0989-48e0-bc82-f20e28b5ec29.png",
185
+ "angel_christmas": "./assets/angel_christmas/mask_f15d9b45-c978-4e3d-9f5f-251e308560c3.png",
186
+ "sunflower_girl": "./assets/sunflower_girl/mask_99cc50b4-7dc4-4de5-8748-ec10772f0317.png",
187
+ "girl_on_sun": "./assets/girl_on_sun/mask_264eac8b-8b65-479c-9755-020a60880c37.png",
188
+ "spider_man_rm": "./assets/spider_man_rm/mask_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png",
189
+ "anime_flower": "./assets/anime_flower/mask_37553172-9b38-4727-bf2e-37d7e2b93461.png",
190
+ "chenduling": "./assets/chenduling/mask_68e3ff6f-da07-4b37-91df-13d6eed7b997.png",
191
+ "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/mask_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png",
192
+ }
193
+ MASKED_IMAGE_PATH = {
194
+ "frog": "./assets/frog/masked_image_f7b350de-6f2c-49e3-b535-995c486d78e7.png",
195
+ "chinese_girl": "./assets/chinese_girl/masked_image_54759648-0989-48e0-bc82-f20e28b5ec29.png",
196
+ "angel_christmas": "./assets/angel_christmas/masked_image_f15d9b45-c978-4e3d-9f5f-251e308560c3.png",
197
+ "sunflower_girl": "./assets/sunflower_girl/masked_image_99cc50b4-7dc4-4de5-8748-ec10772f0317.png",
198
+ "girl_on_sun": "./assets/girl_on_sun/masked_image_264eac8b-8b65-479c-9755-020a60880c37.png",
199
+ "spider_man_rm": "./assets/spider_man_rm/masked_image_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png",
200
+ "anime_flower": "./assets/anime_flower/masked_image_37553172-9b38-4727-bf2e-37d7e2b93461.png",
201
+ "chenduling": "./assets/chenduling/masked_image_68e3ff6f-da07-4b37-91df-13d6eed7b997.png",
202
+ "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/masked_image_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png",
203
+ }
204
+ OUTPUT_IMAGE_PATH = {
205
+ "frog": "./assets/frog/image_edit_f7b350de-6f2c-49e3-b535-995c486d78e7_1.png",
206
+ "chinese_girl": "./assets/chinese_girl/image_edit_54759648-0989-48e0-bc82-f20e28b5ec29_1.png",
207
+ "angel_christmas": "./assets/angel_christmas/image_edit_f15d9b45-c978-4e3d-9f5f-251e308560c3_0.png",
208
+ "sunflower_girl": "./assets/sunflower_girl/image_edit_99cc50b4-7dc4-4de5-8748-ec10772f0317_3.png",
209
+ "girl_on_sun": "./assets/girl_on_sun/image_edit_264eac8b-8b65-479c-9755-020a60880c37_0.png",
210
+ "spider_man_rm": "./assets/spider_man_rm/image_edit_a5d410e6-8e8d-432f-8144-defbc3e1eae9_0.png",
211
+ "anime_flower": "./assets/anime_flower/image_edit_37553172-9b38-4727-bf2e-37d7e2b93461_2.png",
212
+ "chenduling": "./assets/chenduling/image_edit_68e3ff6f-da07-4b37-91df-13d6eed7b997_0.png",
213
+ "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/image_edit_db7f8bf8-8349-46d3-b14e-43d67fbe25d3_3.png",
214
+ }
215
+
216
+
217
+ # os.environ['GRADIO_TEMP_DIR'] = 'gradio_temp_dir'
218
+ # os.makedirs('gradio_temp_dir', exist_ok=True)
219
+
220
+ VLM_MODEL_NAMES = list(vlms_template.keys())
221
+ DEFAULT_VLM_MODEL_NAME = "Qwen2-VL-7B-Instruct (Default)"
222
+ BASE_MODELS = list(base_models_template.keys())
223
+ DEFAULT_BASE_MODEL = "realisticVision (Default)"
224
+
225
+ ASPECT_RATIO_LABELS = list(aspect_ratios)
226
+ DEFAULT_ASPECT_RATIO = ASPECT_RATIO_LABELS[0]
227
+
228
+ device = "cuda:0"
229
+ torch_dtype = torch.bfloat16
230
+
231
+ ## init device
232
+ # try:
233
+ # if torch.cuda.is_available():
234
+ # device = "cuda"
235
+ # print("device = cuda")
236
+ # elif sys.platform == "darwin" and torch.backends.mps.is_available():
237
+ # device = "mps"
238
+ # print("device = mps")
239
+ # else:
240
+ # device = "cpu"
241
+ # print("device = cpu")
242
+ # except:
243
+ # device = "cpu"
244
+
245
+
246
+
247
+ # download hf models
248
+ BrushEdit_path = "models/"
249
+ if not os.path.exists(BrushEdit_path):
250
+ BrushEdit_path = snapshot_download(
251
+ repo_id="TencentARC/BrushEdit",
252
+ local_dir=BrushEdit_path,
253
+ token=os.getenv("HF_TOKEN"),
254
+ )
255
+
256
+ ## init default VLM
257
+ vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[DEFAULT_VLM_MODEL_NAME]
258
+ if vlm_processor != "" and vlm_model != "":
259
+ vlm_model.to(device)
260
+ else:
261
+ raise gr.Error("Please Download default VLM model "+ DEFAULT_VLM_MODEL_NAME +" first.")
262
+
263
+
264
+ ## init base model
265
+ base_model_path = os.path.join(BrushEdit_path, "base_model/realisticVisionV60B1_v51VAE")
266
+ brushnet_path = os.path.join(BrushEdit_path, "brushnetX")
267
+ sam_path = os.path.join(BrushEdit_path, "sam/sam_vit_h_4b8939.pth")
268
+ groundingdino_path = os.path.join(BrushEdit_path, "grounding_dino/groundingdino_swint_ogc.pth")
269
+
270
+
271
+ # input brushnetX ckpt path
272
+ brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch_dtype)
273
+ pipe = StableDiffusionBrushNetPipeline.from_pretrained(
274
+ base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
275
+ )
276
+ # speed up diffusion process with faster scheduler and memory optimization
277
+ pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
278
+ # uncomment the following line to enable xformers memory-efficient attention (not needed if xformers is unavailable or with Torch 2.0)
279
+ # pipe.enable_xformers_memory_efficient_attention()
280
+ pipe.enable_model_cpu_offload()
281
+
282
+
283
+ ## init SAM
284
+ sam = build_sam(checkpoint=sam_path)
285
+ sam.to(device=device)
286
+ sam_predictor = SamPredictor(sam)
287
+ sam_automask_generator = SamAutomaticMaskGenerator(sam)
288
+
289
+ ## init groundingdino_model
290
+ config_file = 'app/utils/GroundingDINO_SwinT_OGC.py'
291
+ groundingdino_model = load_grounding_dino_model(config_file, groundingdino_path, device=device)
292
+
293
+ ## Ordinary function
294
+ def crop_and_resize(image: Image.Image,
295
+ target_width: int,
296
+ target_height: int) -> Image.Image:
297
+ """
298
+ Crops and resizes an image while preserving the aspect ratio.
299
+
300
+ Args:
301
+ image (Image.Image): Input PIL image to be cropped and resized.
302
+ target_width (int): Target width of the output image.
303
+ target_height (int): Target height of the output image.
304
+
305
+ Returns:
306
+ Image.Image: Cropped and resized image.
307
+ """
308
+ # Original dimensions
309
+ original_width, original_height = image.size
310
+ original_aspect = original_width / original_height
311
+ target_aspect = target_width / target_height
312
+
313
+ # Calculate crop box to maintain aspect ratio
314
+ if original_aspect > target_aspect:
315
+ # Crop horizontally
316
+ new_width = int(original_height * target_aspect)
317
+ new_height = original_height
318
+ left = (original_width - new_width) / 2
319
+ top = 0
320
+ right = left + new_width
321
+ bottom = original_height
322
+ else:
323
+ # Crop vertically
324
+ new_width = original_width
325
+ new_height = int(original_width / target_aspect)
326
+ left = 0
327
+ top = (original_height - new_height) / 2
328
+ right = original_width
329
+ bottom = top + new_height
330
+
331
+ # Crop and resize
332
+ cropped_image = image.crop((left, top, right, bottom))
333
+ resized_image = cropped_image.resize((target_width, target_height), Image.NEAREST)
334
+ return resized_image
335
+
336
+
337
+ ## Ordinary function
338
+ def resize(image: Image.Image,
339
+ target_width: int,
340
+ target_height: int) -> Image.Image:
341
+ """
342
+ Resizes an image to the target dimensions without preserving the aspect ratio.
343
+
344
+ Args:
345
+ image (Image.Image): Input PIL image to be cropped and resized.
346
+ target_width (int): Target width of the output image.
347
+ target_height (int): Target height of the output image.
348
+
349
+ Returns:
350
+ Image.Image: Cropped and resized image.
351
+ """
352
+ # Resize directly to the target dimensions
353
+ resized_image = image.resize((target_width, target_height), Image.NEAREST)
354
+ return resized_image
355
+
356
+
357
+ def move_mask_func(mask, direction, units):
358
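+ # shift the binary mask by `units` pixels in the given direction; pixels shifted out are dropped and vacated pixels become background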
+ binary_mask = mask.squeeze()>0
359
+ rows, cols = binary_mask.shape
360
+ moved_mask = np.zeros_like(binary_mask, dtype=bool)
361
+
362
+ if direction == 'down':
363
+ # move down
364
+ moved_mask[max(0, units):, :] = binary_mask[:rows - units, :]
365
+
366
+ elif direction == 'up':
367
+ # move up
368
+ moved_mask[:rows - units, :] = binary_mask[units:, :]
369
+
370
+ elif direction == 'right':
371
+ # move right
372
+ moved_mask[:, max(0, units):] = binary_mask[:, :cols - units]
373
+
374
+ elif direction == 'left':
375
+ # move left
376
+ moved_mask[:, :cols - units] = binary_mask[:, units:]
377
+
378
+ return moved_mask
379
+
380
+
381
+ def random_mask_func(mask, dilation_type='square', dilation_size=20):
382
+ # Derive a modified mask: morphological dilation/erosion, or a bounding box/ellipse around the masked region
383
+ binary_mask = mask.squeeze()>0
384
+
385
+ if dilation_type == 'square_dilation':
386
+ structure = np.ones((dilation_size, dilation_size), dtype=bool)
387
+ dilated_mask = binary_dilation(binary_mask, structure=structure)
388
+ elif dilation_type == 'square_erosion':
389
+ structure = np.ones((dilation_size, dilation_size), dtype=bool)
390
+ dilated_mask = binary_erosion(binary_mask, structure=structure)
391
+ elif dilation_type == 'bounding_box':
392
+ # find the extent (bounding box) of the mask region
393
+ rows, cols = np.where(binary_mask)
394
+ if len(rows) == 0 or len(cols) == 0:
395
+ return mask # return original mask if no valid points
396
+
397
+ min_row = np.min(rows)
398
+ max_row = np.max(rows)
399
+ min_col = np.min(cols)
400
+ max_col = np.max(cols)
401
+
402
+ # create a bounding box
403
+ dilated_mask = np.zeros_like(binary_mask, dtype=bool)
404
+ dilated_mask[min_row:max_row + 1, min_col:max_col + 1] = True
405
+
406
+ elif dilation_type == 'bounding_ellipse':
407
+ # find the extent of the mask region to fit an enclosing ellipse
408
+ rows, cols = np.where(binary_mask)
409
+ if len(rows) == 0 or len(cols) == 0:
410
+ return mask # return original mask if no valid points
411
+
412
+ min_row = np.min(rows)
413
+ max_row = np.max(rows)
414
+ min_col = np.min(cols)
415
+ max_col = np.max(cols)
416
+
417
+ # calculate the center and axis length of the ellipse
418
+ center = ((min_col + max_col) // 2, (min_row + max_row) // 2)
419
+ a = (max_col - min_col) // 2 # semi-axis along the column (x) direction
420
+ b = (max_row - min_row) // 2 # semi-axis along the row (y) direction
421
+
422
+ # create a bounding ellipse
423
+ y, x = np.ogrid[:mask.shape[0], :mask.shape[1]]
424
+ ellipse_mask = ((x - center[0])**2 / a**2 + (y - center[1])**2 / b**2) <= 1
425
+ dilated_mask = np.zeros_like(binary_mask, dtype=bool)
426
+ dilated_mask[ellipse_mask] = True
427
+ else:
428
+ raise ValueError("dilation_type must be 'square_dilation', 'square_erosion', 'bounding_box' or 'bounding_ellipse'")
429
+
430
+ # convert the boolean mask to a 0/255 uint8 mask with a trailing channel dimension
431
+ dilated_mask = np.uint8(dilated_mask[:,:,np.newaxis]) * 255
432
+ return dilated_mask
433
+
434
+
435
+ ## Gradio component function
436
+ def update_vlm_model(vlm_name):
437
+ global vlm_model, vlm_processor
438
+ if vlm_model is not None:
439
+ del vlm_model
440
+ torch.cuda.empty_cache()
441
+
442
+ vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[vlm_name]
443
+
444
+ ## we recommend using preloaded models; otherwise downloading may take a long time. The model list can be edited in vlm_template.py
445
+ if vlm_type == "llava-next":
446
+ if vlm_processor != "" and vlm_model != "":
447
+ vlm_model.to(device)
448
+ return vlm_model_dropdown
449
+ else:
450
+ if os.path.exists(vlm_local_path):
451
+ vlm_processor = LlavaNextProcessor.from_pretrained(vlm_local_path)
452
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto")
453
+ else:
454
+ if vlm_name == "llava-v1.6-mistral-7b-hf (Preload)":
455
+ vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
456
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", torch_dtype="auto", device_map="auto")
457
+ elif vlm_name == "llama3-llava-next-8b-hf (Preload)":
458
+ vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llama3-llava-next-8b-hf")
459
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llama3-llava-next-8b-hf", torch_dtype="auto", device_map="auto")
460
+ elif vlm_name == "llava-v1.6-vicuna-13b-hf (Preload)":
461
+ vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf")
462
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf", torch_dtype="auto", device_map="auto")
463
+ elif vlm_name == "llava-v1.6-34b-hf (Preload)":
464
+ vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-34b-hf")
465
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-34b-hf", torch_dtype="auto", device_map="auto")
466
+ elif vlm_name == "llava-next-72b-hf (Preload)":
467
+ vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-next-72b-hf")
468
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-next-72b-hf", torch_dtype="auto", device_map="auto")
469
+ elif vlm_type == "qwen2-vl":
470
+ if vlm_processor != "" and vlm_model != "":
471
+ vlm_model.to(device)
472
+ return vlm_model_dropdown
473
+ else:
474
+ if os.path.exists(vlm_local_path):
475
+ vlm_processor = Qwen2VLProcessor.from_pretrained(vlm_local_path)
476
+ vlm_model = Qwen2VLForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto")
477
+ else:
478
+ if vlm_name == "qwen2-vl-2b-instruct (Preload)":
479
+ vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
480
+ vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", torch_dtype="auto", device_map="auto")
481
+ elif vlm_name == "qwen2-vl-7b-instruct (Preload)":
482
+ vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
483
+ vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto")
484
+ elif vlm_name == "qwen2-vl-72b-instruct (Preload)":
485
+ vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-72B-Instruct")
486
+ vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-72B-Instruct", torch_dtype="auto", device_map="auto")
487
+ elif vlm_type == "openai":
488
+ pass
489
+ return "success"
490
+
491
+
492
+ def update_base_model(base_model_name):
493
+ global pipe
494
+ ## we recommend using preloaded models; otherwise downloading may take a long time. The model list can be edited in base_model_template.py
495
+ if pipe is not None:
496
+ del pipe
497
+ torch.cuda.empty_cache()
498
+ base_model_path, pipe = base_models_template[base_model_name]
499
+ if pipe != "":
500
+ pipe.to(device)
501
+ else:
502
+ if os.path.exists(base_model_path):
503
+ pipe = StableDiffusionBrushNetPipeline.from_pretrained(
504
+ base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
505
+ )
506
+ # pipe.enable_xformers_memory_efficient_attention()
507
+ pipe.enable_model_cpu_offload()
508
+ else:
509
+ raise gr.Error(f"The base model {base_model_name} does not exist")
510
+ return "success"
511
+
512
+
513
+ def submit_GPT4o_KEY(GPT4o_KEY):
514
+ global vlm_model, vlm_processor
515
+ if vlm_model is not None:
516
+ del vlm_model
517
+ torch.cuda.empty_cache()
518
+ try:
519
+ vlm_model = OpenAI(api_key=GPT4o_KEY)
520
+ vlm_processor = ""
521
+ response = vlm_model.chat.completions.create(
522
+ model="gpt-4o-2024-08-06",
523
+ messages=[
524
+ {"role": "system", "content": "You are a helpful assistant."},
525
+ {"role": "user", "content": "Say this is a test"}
526
+ ]
527
+ )
528
+ response_str = response.choices[0].message.content
529
+
530
+ return "Success, " + response_str, "GPT4-o (Highly Recommended)"
531
+ except Exception as e:
532
+ return "Invalid GPT4o API Key", "GPT4-o (Highly Recommended)"
533
+
534
+
535
+
536
+ def process(input_image,
537
+ original_image,
538
+ original_mask,
539
+ prompt,
540
+ negative_prompt,
541
+ control_strength,
542
+ seed,
543
+ randomize_seed,
544
+ guidance_scale,
545
+ num_inference_steps,
546
+ num_samples,
547
+ blending,
548
+ category,
549
+ target_prompt,
550
+ resize_default,
551
+ aspect_ratio_name,
552
+ invert_mask_state):
553
+ if original_image is None:
554
+ if input_image is None:
555
+ raise gr.Error('Please upload the input image')
556
+ else:
557
+ image_pil = input_image["background"].convert("RGB")
558
+ original_image = np.array(image_pil)
559
+ if prompt is None or prompt == "":
560
+ if target_prompt is None or target_prompt == "":
561
+ raise gr.Error("Please input your instructions, e.g., remove the xxx")
562
+
563
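+ # the Gradio ImageEditor stores brush strokes in layers[0]; its alpha channel is the user-drawn mask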
+ alpha_mask = input_image["layers"][0].split()[3]
564
+ input_mask = np.asarray(alpha_mask)
565
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
566
+ if output_w == "" or output_h == "":
567
+ output_h, output_w = original_image.shape[:2]
568
+
569
+ if resize_default:
570
+ short_side = min(output_w, output_h)
571
+ scale_ratio = 640 / short_side
572
+ output_w = int(output_w * scale_ratio)
573
+ output_h = int(output_h * scale_ratio)
574
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
575
+ original_image = np.array(original_image)
576
+ if input_mask is not None:
577
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
578
+ input_mask = np.array(input_mask)
579
+ if original_mask is not None:
580
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
581
+ original_mask = np.array(original_mask)
582
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
583
+ else:
584
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
585
+ pass
586
+ else:
587
+ if resize_default:
588
+ short_side = min(output_w, output_h)
589
+ scale_ratio = 640 / short_side
590
+ output_w = int(output_w * scale_ratio)
591
+ output_h = int(output_h * scale_ratio)
592
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
593
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
594
+ original_image = np.array(original_image)
595
+ if input_mask is not None:
596
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
597
+ input_mask = np.array(input_mask)
598
+ if original_mask is not None:
599
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
600
+ original_mask = np.array(original_mask)
601
+
602
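+ # keep the stored mask if it was just inverted; otherwise a non-empty brush mask overrides the previously generated one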
+ if invert_mask_state:
603
+ original_mask = original_mask
604
+ else:
605
+ if input_mask.max() == 0:
606
+ original_mask = original_mask
607
+ else:
608
+ original_mask = input_mask
609
+
610
+
611
+ ## inpainting directly if target_prompt is not None
612
+ if category is not None:
613
+ pass
614
+ elif target_prompt is not None and len(target_prompt) >= 1 and original_mask is not None:
615
+ pass
616
+ else:
617
+ try:
618
+ category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device)
619
+ except Exception as e:
620
+ raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
621
+
622
+
623
+ if original_mask is not None:
624
+ original_mask = np.clip(original_mask, 0, 255).astype(np.uint8)
625
+ else:
626
+ try:
627
+ object_wait_for_edit = vlm_response_object_wait_for_edit(
628
+ vlm_processor,
629
+ vlm_model,
630
+ original_image,
631
+ category,
632
+ prompt,
633
+ device)
634
+
635
+ original_mask = vlm_response_mask(vlm_processor,
636
+ vlm_model,
637
+ category,
638
+ original_image,
639
+ prompt,
640
+ object_wait_for_edit,
641
+ sam,
642
+ sam_predictor,
643
+ sam_automask_generator,
644
+ groundingdino_model,
645
+ device).astype(np.uint8)
646
+ except Exception as e:
647
+ raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
648
+
649
+ if original_mask.ndim == 2:
650
+ original_mask = original_mask[:,:,None]
651
+
652
+
653
+ if target_prompt is not None and len(target_prompt) >= 1:
654
+ prompt_after_apply_instruction = target_prompt
655
+
656
+ else:
657
+ try:
658
+ prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(
659
+ vlm_processor,
660
+ vlm_model,
661
+ original_image,
662
+ prompt,
663
+ device)
664
+ except Exception as e:
665
+ raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
666
+
667
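+ # seed the generator for reproducibility, or draw a fresh seed when "Randomize seed" is checked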
+ generator = torch.Generator(device).manual_seed(random.randint(0, 2147483647) if randomize_seed else seed)
668
+
669
+
670
+ with torch.autocast(device):
671
+ image, mask_image, mask_np, init_image_np = BrushEdit_Pipeline(pipe,
672
+ prompt_after_apply_instruction,
673
+ original_mask,
674
+ original_image,
675
+ generator,
676
+ num_inference_steps,
677
+ guidance_scale,
678
+ control_strength,
679
+ negative_prompt,
680
+ num_samples,
681
+ blending)
682
+ original_image = np.array(init_image_np)
683
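+ # preview image: black out the masked region that was repainted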
+ masked_image = original_image * (1 - (mask_np>0))
684
+ masked_image = masked_image.astype(np.uint8)
685
+ masked_image = Image.fromarray(masked_image)
686
+ # Save the images (optional)
687
+ # import uuid
688
+ # uuid = str(uuid.uuid4())
689
+ # image[0].save(f"outputs/image_edit_{uuid}_0.png")
690
+ # image[1].save(f"outputs/image_edit_{uuid}_1.png")
691
+ # image[2].save(f"outputs/image_edit_{uuid}_2.png")
692
+ # image[3].save(f"outputs/image_edit_{uuid}_3.png")
693
+ # mask_image.save(f"outputs/mask_{uuid}.png")
694
+ # masked_image.save(f"outputs/masked_image_{uuid}.png")
695
+ gr.Info(f"Target Prompt: {prompt_after_apply_instruction}", duration=20)
696
+ return image, [mask_image], [masked_image], prompt, '', False
697
+
698
+
699
+ def generate_target_prompt(input_image,
700
+ original_image,
701
+ prompt):
702
+ # load example image
703
+ if isinstance(original_image, str):
704
+ original_image = input_image
705
+
706
+ prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(
707
+ vlm_processor,
708
+ vlm_model,
709
+ original_image,
710
+ prompt,
711
+ device)
712
+ return prompt_after_apply_instruction
713
+
714
+
715
+ def process_mask(input_image,
716
+ original_image,
717
+ prompt,
718
+ resize_default,
719
+ aspect_ratio_name):
720
+ if original_image is None:
721
+ raise gr.Error('Please upload the input image')
722
+ if prompt is None:
723
+ raise gr.Error("Please input your instructions, e.g., remove the xxx")
724
+
725
+ ## load mask
726
+ alpha_mask = input_image["layers"][0].split()[3]
727
+ input_mask = np.array(alpha_mask)
728
+
729
+ # load example image
730
+ if isinstance(original_image, str):
731
+ original_image = input_image["background"]
732
+
733
+ if input_mask.max() == 0:
734
+ category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device)
735
+
736
+ object_wait_for_edit = vlm_response_object_wait_for_edit(vlm_processor,
737
+ vlm_model,
738
+ original_image,
739
+ category,
740
+ prompt,
741
+ device)
742
+ # original mask: h,w,1 [0, 255]
743
+ original_mask = vlm_response_mask(
744
+ vlm_processor,
745
+ vlm_model,
746
+ category,
747
+ original_image,
748
+ prompt,
749
+ object_wait_for_edit,
750
+ sam,
751
+ sam_predictor,
752
+ sam_automask_generator,
753
+ groundingdino_model,
754
+ device).astype(np.uint8)
755
+ else:
756
+ original_mask = input_mask.astype(np.uint8)
757
+ category = None
758
+
759
+ ## resize mask if needed
760
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
761
+ if output_w == "" or output_h == "":
762
+ output_h, output_w = original_image.shape[:2]
763
+ if resize_default:
764
+ short_side = min(output_w, output_h)
765
+ scale_ratio = 640 / short_side
766
+ output_w = int(output_w * scale_ratio)
767
+ output_h = int(output_h * scale_ratio)
768
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
769
+ original_image = np.array(original_image)
770
+ if input_mask is not None:
771
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
772
+ input_mask = np.array(input_mask)
773
+ if original_mask is not None:
774
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
775
+ original_mask = np.array(original_mask)
776
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
777
+ else:
778
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
779
+ pass
780
+ else:
781
+ if resize_default:
782
+ short_side = min(output_w, output_h)
783
+ scale_ratio = 640 / short_side
784
+ output_w = int(output_w * scale_ratio)
785
+ output_h = int(output_h * scale_ratio)
786
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
787
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
788
+ original_image = np.array(original_image)
789
+ if input_mask is not None:
790
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
791
+ input_mask = np.array(input_mask)
792
+ if original_mask is not None:
793
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
794
+ original_mask = np.array(original_mask)
795
+
796
+
797
+ if original_mask.ndim == 2:
798
+ original_mask = original_mask[:,:,None]
799
+
800
+ mask_image = Image.fromarray(original_mask.squeeze().astype(np.uint8)).convert("RGB")
801
+
802
+ masked_image = original_image * (1 - (original_mask>0))
803
+ masked_image = masked_image.astype(np.uint8)
804
+ masked_image = Image.fromarray(masked_image)
805
+
806
+ return [masked_image], [mask_image], original_mask.astype(np.uint8), category
807
+
808
+
809
+ def process_random_mask(input_image,
810
+ original_image,
811
+ original_mask,
812
+ resize_default,
813
+ aspect_ratio_name,
814
+ ):
815
+
816
+ alpha_mask = input_image["layers"][0].split()[3]
817
+ input_mask = np.asarray(alpha_mask)
818
+
819
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
820
+ if output_w == "" or output_h == "":
821
+ output_h, output_w = original_image.shape[:2]
822
+ if resize_default:
823
+ short_side = min(output_w, output_h)
824
+ scale_ratio = 640 / short_side
825
+ output_w = int(output_w * scale_ratio)
826
+ output_h = int(output_h * scale_ratio)
827
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
828
+ original_image = np.array(original_image)
829
+ if input_mask is not None:
830
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
831
+ input_mask = np.array(input_mask)
832
+ if original_mask is not None:
833
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
834
+ original_mask = np.array(original_mask)
835
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
836
+ else:
837
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
838
+ pass
839
+ else:
840
+ if resize_default:
841
+ short_side = min(output_w, output_h)
842
+ scale_ratio = 640 / short_side
843
+ output_w = int(output_w * scale_ratio)
844
+ output_h = int(output_h * scale_ratio)
845
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
846
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
847
+ original_image = np.array(original_image)
848
+ if input_mask is not None:
849
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
850
+ input_mask = np.array(input_mask)
851
+ if original_mask is not None:
852
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
853
+ original_mask = np.array(original_mask)
854
+
855
+
856
+ if input_mask.max() == 0:
857
+ original_mask = original_mask
858
+ else:
859
+ original_mask = input_mask
860
+
861
+ if original_mask is None:
862
+ raise gr.Error('Please generate mask first')
863
+
864
+ if original_mask.ndim == 2:
865
+ original_mask = original_mask[:,:,None]
866
+
867
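+ # replace the mask with a randomly chosen bounding box or bounding ellipse around the current mask region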
+ dilation_type = np.random.choice(['bounding_box', 'bounding_ellipse'])
868
+ random_mask = random_mask_func(original_mask, dilation_type).squeeze()
869
+
870
+ mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
871
+
872
+ masked_image = original_image * (1 - (random_mask[:,:,None]>0))
873
+ masked_image = masked_image.astype(original_image.dtype)
874
+ masked_image = Image.fromarray(masked_image)
875
+
876
+
877
+ return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
878
+
879
+
880
+ def process_dilation_mask(input_image,
881
+ original_image,
882
+ original_mask,
883
+ resize_default,
884
+ aspect_ratio_name,
885
+ dilation_size=20):
886
+
887
+ alpha_mask = input_image["layers"][0].split()[3]
888
+ input_mask = np.asarray(alpha_mask)
889
+
890
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
891
+ if output_w == "" or output_h == "":
892
+ output_h, output_w = original_image.shape[:2]
893
+ if resize_default:
894
+ short_side = min(output_w, output_h)
895
+ scale_ratio = 640 / short_side
896
+ output_w = int(output_w * scale_ratio)
897
+ output_h = int(output_h * scale_ratio)
898
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
899
+ original_image = np.array(original_image)
900
+ if input_mask is not None:
901
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
902
+ input_mask = np.array(input_mask)
903
+ if original_mask is not None:
904
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
905
+ original_mask = np.array(original_mask)
906
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
907
+ else:
908
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
909
+ pass
910
+ else:
911
+ if resize_default:
912
+ short_side = min(output_w, output_h)
913
+ scale_ratio = 640 / short_side
914
+ output_w = int(output_w * scale_ratio)
915
+ output_h = int(output_h * scale_ratio)
916
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
917
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
918
+ original_image = np.array(original_image)
919
+ if input_mask is not None:
920
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
921
+ input_mask = np.array(input_mask)
922
+ if original_mask is not None:
923
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
924
+ original_mask = np.array(original_mask)
925
+
926
+ if input_mask.max() == 0:
927
+ original_mask = original_mask
928
+ else:
929
+ original_mask = input_mask
930
+
931
+ if original_mask is None:
932
+ raise gr.Error('Please generate mask first')
933
+
934
+ if original_mask.ndim == 2:
935
+ original_mask = original_mask[:,:,None]
936
+
937
+ dilation_type = np.random.choice(['square_dilation'])
938
+ random_mask = random_mask_func(original_mask, dilation_type, dilation_size).squeeze()
939
+
940
+ mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
941
+
942
+ masked_image = original_image * (1 - (random_mask[:,:,None]>0))
943
+ masked_image = masked_image.astype(original_image.dtype)
944
+ masked_image = Image.fromarray(masked_image)
945
+
946
+ return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
947
+
948
+
949
+ def process_erosion_mask(input_image,
950
+ original_image,
951
+ original_mask,
952
+ resize_default,
953
+ aspect_ratio_name,
954
+ dilation_size=20):
955
+ alpha_mask = input_image["layers"][0].split()[3]
956
+ input_mask = np.asarray(alpha_mask)
957
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
958
+ if output_w == "" or output_h == "":
959
+ output_h, output_w = original_image.shape[:2]
960
+ if resize_default:
961
+ short_side = min(output_w, output_h)
962
+ scale_ratio = 640 / short_side
963
+ output_w = int(output_w * scale_ratio)
964
+ output_h = int(output_h * scale_ratio)
965
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
966
+ original_image = np.array(original_image)
967
+ if input_mask is not None:
968
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
969
+ input_mask = np.array(input_mask)
970
+ if original_mask is not None:
971
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
972
+ original_mask = np.array(original_mask)
973
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
974
+ else:
975
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
976
+ pass
977
+ else:
978
+ if resize_default:
979
+ short_side = min(output_w, output_h)
980
+ scale_ratio = 640 / short_side
981
+ output_w = int(output_w * scale_ratio)
982
+ output_h = int(output_h * scale_ratio)
983
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
984
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
985
+ original_image = np.array(original_image)
986
+ if input_mask is not None:
987
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
988
+ input_mask = np.array(input_mask)
989
+ if original_mask is not None:
990
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
991
+ original_mask = np.array(original_mask)
992
+
993
+ if input_mask.max() == 0:
994
+ original_mask = original_mask
995
+ else:
996
+ original_mask = input_mask
997
+
998
+ if original_mask is None:
999
+ raise gr.Error('Please generate mask first')
1000
+
1001
+ if original_mask.ndim == 2:
1002
+ original_mask = original_mask[:,:,None]
1003
+
1004
+ dilation_type = np.random.choice(['square_erosion'])
1005
+ random_mask = random_mask_func(original_mask, dilation_type, dilation_size).squeeze()
1006
+
1007
+ mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
1008
+
1009
+ masked_image = original_image * (1 - (random_mask[:,:,None]>0))
1010
+ masked_image = masked_image.astype(original_image.dtype)
1011
+ masked_image = Image.fromarray(masked_image)
1012
+
1013
+
1014
+ return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
1015
+
1016
+
1017
+ def move_mask_left(input_image,
1018
+ original_image,
1019
+ original_mask,
1020
+ moving_pixels,
1021
+ resize_default,
1022
+ aspect_ratio_name):
1023
+
1024
+ alpha_mask = input_image["layers"][0].split()[3]
1025
+ input_mask = np.asarray(alpha_mask)
1026
+
1027
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
1028
+ if output_w == "" or output_h == "":
1029
+ output_h, output_w = original_image.shape[:2]
1030
+ if resize_default:
1031
+ short_side = min(output_w, output_h)
1032
+ scale_ratio = 640 / short_side
1033
+ output_w = int(output_w * scale_ratio)
1034
+ output_h = int(output_h * scale_ratio)
1035
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1036
+ original_image = np.array(original_image)
1037
+ if input_mask is not None:
1038
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1039
+ input_mask = np.array(input_mask)
1040
+ if original_mask is not None:
1041
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1042
+ original_mask = np.array(original_mask)
1043
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1044
+ else:
1045
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1046
+ pass
1047
+ else:
1048
+ if resize_default:
1049
+ short_side = min(output_w, output_h)
1050
+ scale_ratio = 640 / short_side
1051
+ output_w = int(output_w * scale_ratio)
1052
+ output_h = int(output_h * scale_ratio)
1053
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1054
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1055
+ original_image = np.array(original_image)
1056
+ if input_mask is not None:
1057
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1058
+ input_mask = np.array(input_mask)
1059
+ if original_mask is not None:
1060
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1061
+ original_mask = np.array(original_mask)
1062
+
1063
+ if input_mask.max() == 0:
1064
+ original_mask = original_mask
1065
+ else:
1066
+ original_mask = input_mask
1067
+
1068
+ if original_mask is None:
1069
+ raise gr.Error('Please generate mask first')
1070
+
1071
+ if original_mask.ndim == 2:
1072
+ original_mask = original_mask[:,:,None]
1073
+
1074
+ moved_mask = move_mask_func(original_mask, 'left', int(moving_pixels)).squeeze()
1075
+ mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
1076
+
1077
+ masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
1078
+ masked_image = masked_image.astype(original_image.dtype)
1079
+ masked_image = Image.fromarray(masked_image)
1080
+
1081
+ if moved_mask.max() <= 1:
1082
+ moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
1083
+ original_mask = moved_mask
1084
+ return [masked_image], [mask_image], original_mask.astype(np.uint8)
1085
+
1086
+
1087
+ def move_mask_right(input_image,
1088
+ original_image,
1089
+ original_mask,
1090
+ moving_pixels,
1091
+ resize_default,
1092
+ aspect_ratio_name):
1093
+ alpha_mask = input_image["layers"][0].split()[3]
1094
+ input_mask = np.asarray(alpha_mask)
1095
+
1096
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
1097
+ if output_w == "" or output_h == "":
1098
+ output_h, output_w = original_image.shape[:2]
1099
+ if resize_default:
1100
+ short_side = min(output_w, output_h)
1101
+ scale_ratio = 640 / short_side
1102
+ output_w = int(output_w * scale_ratio)
1103
+ output_h = int(output_h * scale_ratio)
1104
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1105
+ original_image = np.array(original_image)
1106
+ if input_mask is not None:
1107
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1108
+ input_mask = np.array(input_mask)
1109
+ if original_mask is not None:
1110
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1111
+ original_mask = np.array(original_mask)
1112
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1113
+ else:
1114
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1115
+ pass
1116
+ else:
1117
+ if resize_default:
1118
+ short_side = min(output_w, output_h)
1119
+ scale_ratio = 640 / short_side
1120
+ output_w = int(output_w * scale_ratio)
1121
+ output_h = int(output_h * scale_ratio)
1122
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1123
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1124
+ original_image = np.array(original_image)
1125
+ if input_mask is not None:
1126
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1127
+ input_mask = np.array(input_mask)
1128
+ if original_mask is not None:
1129
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1130
+ original_mask = np.array(original_mask)
1131
+
1132
+ if input_mask.max() == 0:
1133
+ original_mask = original_mask
1134
+ else:
1135
+ original_mask = input_mask
1136
+
1137
+ if original_mask is None:
1138
+ raise gr.Error('Please generate mask first')
1139
+
1140
+ if original_mask.ndim == 2:
1141
+ original_mask = original_mask[:,:,None]
1142
+
1143
+ moved_mask = move_mask_func(original_mask, 'right', int(moving_pixels)).squeeze()
1144
+
1145
+ mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
1146
+
1147
+ masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
1148
+ masked_image = masked_image.astype(original_image.dtype)
1149
+ masked_image = Image.fromarray(masked_image)
1150
+
1151
+
1152
+ if moved_mask.max() <= 1:
1153
+ moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
1154
+ original_mask = moved_mask
1155
+
1156
+ return [masked_image], [mask_image], original_mask.astype(np.uint8)
1157
+
1158
+
1159
+ def move_mask_up(input_image,
1160
+ original_image,
1161
+ original_mask,
1162
+ moving_pixels,
1163
+ resize_default,
1164
+ aspect_ratio_name):
1165
+ alpha_mask = input_image["layers"][0].split()[3]
1166
+ input_mask = np.asarray(alpha_mask)
1167
+
1168
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
1169
+ if output_w == "" or output_h == "":
1170
+ output_h, output_w = original_image.shape[:2]
1171
+ if resize_default:
1172
+ short_side = min(output_w, output_h)
1173
+ scale_ratio = 640 / short_side
1174
+ output_w = int(output_w * scale_ratio)
1175
+ output_h = int(output_h * scale_ratio)
1176
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1177
+ original_image = np.array(original_image)
1178
+ if input_mask is not None:
1179
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1180
+ input_mask = np.array(input_mask)
1181
+ if original_mask is not None:
1182
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1183
+ original_mask = np.array(original_mask)
1184
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1185
+ else:
1186
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1187
+ pass
1188
+ else:
1189
+ if resize_default:
1190
+ short_side = min(output_w, output_h)
1191
+ scale_ratio = 640 / short_side
1192
+ output_w = int(output_w * scale_ratio)
1193
+ output_h = int(output_h * scale_ratio)
1194
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1195
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1196
+ original_image = np.array(original_image)
1197
+ if input_mask is not None:
1198
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1199
+ input_mask = np.array(input_mask)
1200
+ if original_mask is not None:
1201
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1202
+ original_mask = np.array(original_mask)
1203
+
1204
+ if input_mask.max() == 0:
1205
+ original_mask = original_mask
1206
+ else:
1207
+ original_mask = input_mask
1208
+
1209
+ if original_mask is None:
1210
+ raise gr.Error('Please generate mask first')
1211
+
1212
+ if original_mask.ndim == 2:
1213
+ original_mask = original_mask[:,:,None]
1214
+
1215
+ moved_mask = move_mask_func(original_mask, 'up', int(moving_pixels)).squeeze()
1216
+ mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
1217
+
1218
+ masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
1219
+ masked_image = masked_image.astype(original_image.dtype)
1220
+ masked_image = Image.fromarray(masked_image)
1221
+
1222
+ if moved_mask.max() <= 1:
1223
+ moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
1224
+ original_mask = moved_mask
1225
+
1226
+ return [masked_image], [mask_image], original_mask.astype(np.uint8)
1227
+
1228
+
1229
+ def move_mask_down(input_image,
1230
+ original_image,
1231
+ original_mask,
1232
+ moving_pixels,
1233
+ resize_default,
1234
+ aspect_ratio_name):
1235
+ alpha_mask = input_image["layers"][0].split()[3]
1236
+ input_mask = np.asarray(alpha_mask)
1237
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
1238
+ if output_w == "" or output_h == "":
1239
+ output_h, output_w = original_image.shape[:2]
1240
+ if resize_default:
1241
+ short_side = min(output_w, output_h)
1242
+ scale_ratio = 640 / short_side
1243
+ output_w = int(output_w * scale_ratio)
1244
+ output_h = int(output_h * scale_ratio)
1245
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1246
+ original_image = np.array(original_image)
1247
+ if input_mask is not None:
1248
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1249
+ input_mask = np.array(input_mask)
1250
+ if original_mask is not None:
1251
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1252
+ original_mask = np.array(original_mask)
1253
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1254
+ else:
1255
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1256
+ pass
1257
+ else:
1258
+ if resize_default:
1259
+ short_side = min(output_w, output_h)
1260
+ scale_ratio = 640 / short_side
1261
+ output_w = int(output_w * scale_ratio)
1262
+ output_h = int(output_h * scale_ratio)
1263
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1264
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1265
+ original_image = np.array(original_image)
1266
+ if input_mask is not None:
1267
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1268
+ input_mask = np.array(input_mask)
1269
+ if original_mask is not None:
1270
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1271
+ original_mask = np.array(original_mask)
1272
+
1273
+ if input_mask.max() == 0:
1274
+ original_mask = original_mask
1275
+ else:
1276
+ original_mask = input_mask
1277
+
1278
+ if original_mask is None:
1279
+ raise gr.Error('Please generate mask first')
1280
+
1281
+ if original_mask.ndim == 2:
1282
+ original_mask = original_mask[:,:,None]
1283
+
1284
+ moved_mask = move_mask_func(original_mask, 'down', int(moving_pixels)).squeeze()
1285
+ mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
1286
+
1287
+ masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
1288
+ masked_image = masked_image.astype(original_image.dtype)
1289
+ masked_image = Image.fromarray(masked_image)
1290
+
1291
+ if moved_mask.max() <= 1:
1292
+ moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
1293
+ original_mask = moved_mask
1294
+
1295
+ return [masked_image], [mask_image], original_mask.astype(np.uint8)
1296
+
1297
+
1298
+ def invert_mask(input_image,
1299
+ original_image,
1300
+ original_mask,
1301
+ ):
1302
+ alpha_mask = input_image["layers"][0].split()[3]
1303
+ input_mask = np.asarray(alpha_mask)
1304
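+ # invert whichever mask is available: a freshly drawn brush mask takes precedence over the stored one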
+ if input_mask.max() == 0:
1305
+ original_mask = 1 - (original_mask>0).astype(np.uint8)
1306
+ else:
1307
+ original_mask = 1 - (input_mask>0).astype(np.uint8)
1308
+
1309
+ if original_mask is None:
1310
+ raise gr.Error('Please generate mask first')
1311
+
1312
+ original_mask = original_mask.squeeze()
1313
+ mask_image = Image.fromarray(original_mask*255).convert("RGB")
1314
+
1315
+ if original_mask.ndim == 2:
1316
+ original_mask = original_mask[:,:,None]
1317
+
1318
+ if original_mask.max() <= 1:
1319
+ original_mask = (original_mask * 255).astype(np.uint8)
1320
+
1321
+ masked_image = original_image * (1 - (original_mask>0))
1322
+ masked_image = masked_image.astype(original_image.dtype)
1323
+ masked_image = Image.fromarray(masked_image)
1324
+
1325
+ return [masked_image], [mask_image], original_mask, True
1326
+
1327
+
1328
+ def init_img(base,
1329
+ init_type,
1330
+ prompt,
1331
+ aspect_ratio,
1332
+ example_change_times
1333
+ ):
1334
+ image_pil = base["background"].convert("RGB")
1335
+ original_image = np.array(image_pil)
1336
+ if max(original_image.shape[0], original_image.shape[1]) * 1.0 / min(original_image.shape[0], original_image.shape[1])>2.0:
1337
+ raise gr.Error('image aspect ratio cannot be larger than 2.0')
1338
+ if init_type in MASK_IMAGE_PATH.keys() and example_change_times < 2:
1339
+ mask_gallery = [Image.open(MASK_IMAGE_PATH[init_type]).convert("L")]
1340
+ masked_gallery = [Image.open(MASKED_IMAGE_PATH[init_type]).convert("RGB")]
1341
+ result_gallery = [Image.open(OUTPUT_IMAGE_PATH[init_type]).convert("RGB")]
1342
+ width, height = image_pil.size
1343
+ image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
1344
+ height_new, width_new = image_processor.get_default_height_width(image_pil, height, width)
1345
+ image_pil = image_pil.resize((width_new, height_new))
1346
+ mask_gallery[0] = mask_gallery[0].resize((width_new, height_new))
1347
+ masked_gallery[0] = masked_gallery[0].resize((width_new, height_new))
1348
+ result_gallery[0] = result_gallery[0].resize((width_new, height_new))
1349
+ original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1
1350
+ return base, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, "", "", "Custom resolution", False, False, example_change_times
1351
+ else:
1352
+ if aspect_ratio not in ASPECT_RATIO_LABELS:
1353
+ aspect_ratio = "Custom resolution"
1354
+ return base, original_image, None, "", None, None, None, "", "", aspect_ratio, True, False, 0
1355
+
1356
+
1357
+ def reset_func(input_image,
1358
+ original_image,
1359
+ original_mask,
1360
+ prompt,
1361
+ target_prompt,
1362
+ ):
1363
+ input_image = None
1364
+ original_image = None
1365
+ original_mask = None
1366
+ prompt = ''
1367
+ mask_gallery = []
1368
+ masked_gallery = []
1369
+ result_gallery = []
1370
+ target_prompt = ''
1371
+ if torch.cuda.is_available():
1372
+ torch.cuda.empty_cache()
1373
+ return input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, True, False
1374
+
1375
+
1376
+ def update_example(example_type,
1377
+ prompt,
1378
+ example_change_times):
1379
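+ # load the example's precomputed mask/masked/result images so outputs appear without re-running the pipeline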
+ input_image = INPUT_IMAGE_PATH[example_type]
1380
+ image_pil = Image.open(input_image).convert("RGB")
1381
+ mask_gallery = [Image.open(MASK_IMAGE_PATH[example_type]).convert("L")]
1382
+ masked_gallery = [Image.open(MASKED_IMAGE_PATH[example_type]).convert("RGB")]
1383
+ result_gallery = [Image.open(OUTPUT_IMAGE_PATH[example_type]).convert("RGB")]
1384
+ width, height = image_pil.size
1385
+ image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
1386
+ height_new, width_new = image_processor.get_default_height_width(image_pil, height, width)
1387
+ image_pil = image_pil.resize((width_new, height_new))
1388
+ mask_gallery[0] = mask_gallery[0].resize((width_new, height_new))
1389
+ masked_gallery[0] = masked_gallery[0].resize((width_new, height_new))
1390
+ result_gallery[0] = result_gallery[0].resize((width_new, height_new))
1391
+
1392
+ original_image = np.array(image_pil)
1393
+ original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1
1394
+ aspect_ratio = "Custom resolution"
1395
+ example_change_times += 1
1396
+ return input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, "", False, example_change_times
1397
+
1398
+
1399
+ block = gr.Blocks(
1400
+ theme=gr.themes.Soft(
1401
+ radius_size=gr.themes.sizes.radius_none,
1402
+ text_size=gr.themes.sizes.text_md
1403
+ )
1404
+ )
1405
+ with block as demo:
1406
+ with gr.Row():
1407
+ with gr.Column():
1408
+ gr.HTML(head)
1409
+
1410
+ gr.Markdown(descriptions)
1411
+
1412
+ with gr.Accordion(label="🧭 Instructions:", open=True, elem_id="accordion"):
1413
+ with gr.Row(equal_height=True):
1414
+ gr.Markdown(instructions)
1415
+
1416
+ original_image = gr.State(value=None)
1417
+ original_mask = gr.State(value=None)
1418
+ category = gr.State(value=None)
1419
+ status = gr.State(value=None)
1420
+ invert_mask_state = gr.State(value=False)
1421
+ example_change_times = gr.State(value=0)
1422
+
1423
+
1424
+ with gr.Row():
1425
+ with gr.Column():
1426
+ with gr.Row():
1427
+ input_image = gr.ImageEditor(
1428
+ label="Input Image",
1429
+ type="pil",
1430
+ brush=gr.Brush(colors=["#FFFFFF"], default_size = 30, color_mode="fixed"),
1431
+ layers = False,
1432
+ interactive=True,
1433
+ height=1024,
1434
+ sources=["upload"],
1435
+ placeholder="Please click here or the icon below to upload the image.",
1436
+ )
1437
+
1438
+ prompt = gr.Textbox(label="⌨️ Instruction", placeholder="Please input your instruction.", value="",lines=1)
1439
+ run_button = gr.Button("💫 Run")
1440
+
1441
+ vlm_model_dropdown = gr.Dropdown(label="VLM model", choices=VLM_MODEL_NAMES, value=DEFAULT_VLM_MODEL_NAME, interactive=True)
1442
+ with gr.Group():
1443
+ with gr.Row():
1444
+ GPT4o_KEY = gr.Textbox(label="GPT4o API Key", placeholder="Please input your GPT4o API Key when using the GPT4o VLM (highly recommended).", value="", lines=1)
1445
+
1446
+ GPT4o_KEY_submit = gr.Button("Submit and Verify")
1447
+
1448
+
1449
+ aspect_ratio = gr.Dropdown(label="Output aspect ratio", choices=ASPECT_RATIO_LABELS, value=DEFAULT_ASPECT_RATIO)
1450
+ resize_default = gr.Checkbox(label="Short edge resize to 640px", value=True)
1451
+
1452
+ with gr.Row():
1453
+ mask_button = gr.Button("Generate Mask")
1454
+ random_mask_button = gr.Button("Square/Circle Mask")
1455
+
1456
+
1457
+ with gr.Row():
1458
+ generate_target_prompt_button = gr.Button("Generate Target Prompt")
1459
+
1460
+ target_prompt = gr.Text(
1461
+ label="Input Target Prompt",
1462
+ max_lines=5,
1463
+ placeholder="VLM-generated target prompt, you can first generate if and then modify it (optional)",
1464
+ value='',
1465
+ lines=2
1466
+ )
1467
+
1468
+ with gr.Accordion("Advanced Options", open=False, elem_id="accordion1"):
1469
+ base_model_dropdown = gr.Dropdown(label="Base model", choices=BASE_MODELS, value=DEFAULT_BASE_MODEL, interactive=True)
1470
+ negative_prompt = gr.Text(
1471
+ label="Negative Prompt",
1472
+ max_lines=5,
1473
+ placeholder="Please input your negative prompt",
1474
+ value='ugly, low quality',lines=1
1475
+ )
1476
+
1477
+ control_strength = gr.Slider(
1478
+ label="Control Strength: ", show_label=True, minimum=0, maximum=1.1, value=1, step=0.01
1479
+ )
1480
+ with gr.Group():
1481
+ seed = gr.Slider(
1482
+ label="Seed: ", minimum=0, maximum=2147483647, step=1, value=648464818
1483
+ )
1484
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
1485
+
1486
+ blending = gr.Checkbox(label="Blending mode", value=True)
1487
+
1488
+
1489
+ num_samples = gr.Slider(
1490
+ label="Num samples", minimum=0, maximum=4, step=1, value=4
1491
+ )
1492
+
1493
+ with gr.Group():
1494
+ with gr.Row():
1495
+ guidance_scale = gr.Slider(
1496
+ label="Guidance scale",
1497
+ minimum=1,
1498
+ maximum=12,
1499
+ step=0.1,
1500
+ value=7.5,
1501
+ )
1502
+ num_inference_steps = gr.Slider(
1503
+ label="Number of inference steps",
1504
+ minimum=1,
1505
+ maximum=50,
1506
+ step=1,
1507
+ value=50,
1508
+ )
1509
+
1510
+
1511
+ with gr.Column():
1512
+ with gr.Row():
1513
+ with gr.Tab(elem_classes="feedback", label="Masked Image"):
1514
+ masked_gallery = gr.Gallery(label='Masked Image', show_label=True, elem_id="gallery", preview=True, height=360)
1515
+ with gr.Tab(elem_classes="feedback", label="Mask"):
1516
+ mask_gallery = gr.Gallery(label='Mask', show_label=True, elem_id="gallery", preview=True, height=360)
1517
+
1518
+ invert_mask_button = gr.Button("Invert Mask")
1519
+ dilation_size = gr.Slider(
1520
+ label="Dilation size: ", minimum=0, maximum=50, step=1, value=20
1521
+ )
1522
+ with gr.Row():
1523
+ dilation_mask_button = gr.Button("Dilate Generated Mask")
1524
+ erosion_mask_button = gr.Button("Erode Generated Mask")
1525
+
1526
+ moving_pixels = gr.Slider(
1527
+ label="Moving pixels:", show_label=True, minimum=0, maximum=50, value=4, step=1
1528
+ )
1529
+ with gr.Row():
1530
+ move_left_button = gr.Button("Move Left")
1531
+ move_right_button = gr.Button("Move Right")
1532
+ with gr.Row():
1533
+ move_up_button = gr.Button("Move Up")
1534
+ move_down_button = gr.Button("Move Down")
1535
+
1536
+ with gr.Tab(elem_classes="feedback", label="Output"):
1537
+ result_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", preview=True, height=400)
1538
+
1539
+ # target_prompt_output = gr.Text(label="Output Target Prompt", value="", lines=1, interactive=False)
1540
+
1541
+ reset_button = gr.Button("Reset")
1542
+
1543
+ init_type = gr.Textbox(label="Init Name", value="", visible=False)
1544
+ example_type = gr.Textbox(label="Example Name", value="", visible=False)
1545
+
1546
+
1547
+
1548
+ with gr.Row():
1549
+ example = gr.Examples(
1550
+ label="Quick Example",
1551
+ examples=EXAMPLES,
1552
+ inputs=[input_image, prompt, seed, init_type, example_type, blending, resize_default, vlm_model_dropdown],
1553
+ examples_per_page=10,
1554
+ cache_examples=False,
1555
+ )
1556
+
1557
+
1558
+ with gr.Accordion(label="🎬 Feature Details:", open=True, elem_id="accordion"):
1559
+ with gr.Row(equal_height=True):
1560
+ gr.Markdown(tips)
1561
+
1562
+ with gr.Row():
1563
+ gr.Markdown(citation)
1564
+
1565
+ ## gr.Examples cannot be used to update a gr.Gallery, so the following two callbacks update the galleries instead.
1566
+ ## They also resolve the conflict between the upload and change-example callbacks.
1567
+ input_image.upload(
1568
+ init_img,
1569
+ [input_image, init_type, prompt, aspect_ratio, example_change_times],
1570
+ [input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, init_type, aspect_ratio, resize_default, invert_mask_state, example_change_times]
1571
+ )
1572
+ example_type.change(fn=update_example, inputs=[example_type, prompt, example_change_times], outputs=[input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, target_prompt, invert_mask_state, example_change_times])
1573
+
1574
+ ## vlm and base model dropdown
1575
+ vlm_model_dropdown.change(fn=update_vlm_model, inputs=[vlm_model_dropdown], outputs=[status])
1576
+ base_model_dropdown.change(fn=update_base_model, inputs=[base_model_dropdown], outputs=[status])
1577
+
1578
+
1579
+ GPT4o_KEY_submit.click(fn=submit_GPT4o_KEY, inputs=[GPT4o_KEY], outputs=[GPT4o_KEY, vlm_model_dropdown])
1580
+ invert_mask_button.click(fn=invert_mask, inputs=[input_image, original_image, original_mask], outputs=[masked_gallery, mask_gallery, original_mask, invert_mask_state])
1581
+
1582
+
1583
+ ips=[input_image,
1584
+ original_image,
1585
+ original_mask,
1586
+ prompt,
1587
+ negative_prompt,
1588
+ control_strength,
1589
+ seed,
1590
+ randomize_seed,
1591
+ guidance_scale,
1592
+ num_inference_steps,
1593
+ num_samples,
1594
+ blending,
1595
+ category,
1596
+ target_prompt,
1597
+ resize_default,
1598
+ aspect_ratio,
1599
+ invert_mask_state]
1600
+
1601
+ ## run brushedit
1602
+ run_button.click(fn=process, inputs=ips, outputs=[result_gallery, mask_gallery, masked_gallery, prompt, target_prompt, invert_mask_state])
1603
+
1604
+ ## mask func
1605
+ mask_button.click(fn=process_mask, inputs=[input_image, original_image, prompt, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask, category])
1606
+ random_mask_button.click(fn=process_random_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
1607
+ dilation_mask_button.click(fn=process_dilation_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[ masked_gallery, mask_gallery, original_mask])
1608
+ erosion_mask_button.click(fn=process_erosion_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[ masked_gallery, mask_gallery, original_mask])
1609
+
1610
+ ## move mask func
1611
+ move_left_button.click(fn=move_mask_left, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
1612
+ move_right_button.click(fn=move_mask_right, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
1613
+ move_up_button.click(fn=move_mask_up, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
1614
+ move_down_button.click(fn=move_mask_down, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
1615
+
1616
+ ## prompt func
1617
+ generate_target_prompt_button.click(fn=generate_target_prompt, inputs=[input_image, original_image, prompt], outputs=[target_prompt])
1618
+
1619
+ ## reset func
1620
+ reset_button.click(fn=reset_func, inputs=[input_image, original_image, original_mask, prompt, target_prompt], outputs=[input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, resize_default, invert_mask_state])
1621
+
1622
+ ## if you hit a localhost access error, try launching with the settings below
1623
+ demo.launch(server_name="0.0.0.0", server_port=12345, share=True)
1624
+ # demo.launch()
brushedit_app_315_2.py ADDED
@@ -0,0 +1,1627 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import os, random, sys
4
+ import numpy as np
5
+ import requests
6
+ import torch
7
+
8
+
9
+ import gradio as gr
10
+
11
+ from PIL import Image
12
+
13
+
14
+ from huggingface_hub import hf_hub_download, snapshot_download
15
+ from scipy.ndimage import binary_dilation, binary_erosion
16
+ from transformers import (LlavaNextProcessor, LlavaNextForConditionalGeneration,
17
+ Qwen2VLForConditionalGeneration, Qwen2VLProcessor)
18
+
19
+ from segment_anything import SamPredictor, build_sam, SamAutomaticMaskGenerator
20
+ from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler
21
+ from diffusers.image_processor import VaeImageProcessor
22
+
23
+
24
+ from app.src.vlm_pipeline import (
25
+ vlm_response_editing_type,
26
+ vlm_response_object_wait_for_edit,
27
+ vlm_response_mask,
28
+ vlm_response_prompt_after_apply_instruction
29
+ )
30
+ from app.src.brushedit_all_in_one_pipeline import BrushEdit_Pipeline
31
+ from app.utils.utils import load_grounding_dino_model
32
+
33
+ from app.src.vlm_template import vlms_template
34
+ from app.src.base_model_template import base_models_template
35
+ from app.src.aspect_ratio_template import aspect_ratios
36
+
37
+ from openai import OpenAI
38
+ # base_openai_url = ""
39
+
40
+ #### Description ####
41
+ logo = r"""
42
+ <center><img src='./assets/logo_brushedit.png' alt='BrushEdit logo' style="width:80px; margin-bottom:10px"></center>
43
+ """
44
+ head = r"""
45
+ <div style="text-align: center;">
46
+ <h1> BrushEdit: All-In-One Image Inpainting and Editing</h1>
47
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
48
+ <a href='https://liyaowei-stu.github.io/project/BrushEdit/'><img src='https://img.shields.io/badge/Project_Page-BrushEdit-green' alt='Project Page'></a>
49
+ <a href='https://arxiv.org/abs/2412.10316'><img src='https://img.shields.io/badge/Paper-Arxiv-blue'></a>
50
+ <a href='https://github.com/TencentARC/BrushEdit'><img src='https://img.shields.io/badge/Code-Github-orange'></a>
51
+
52
+ </div>
53
+ </br>
54
+ </div>
55
+ """
56
+ descriptions = r"""
57
+ Official Gradio Demo
58
+ """
59
+
60
+ instructions = r"""
61
+ To be added.
62
+ """
63
+
64
+ tips = r"""
65
+ To be added.
66
+
67
+ """
68
+
69
+
70
+
71
+ citation = r"""
72
+ To be added.
73
+ """
74
+
75
+ # - - - - - examples - - - - - #
76
+ EXAMPLES = [
77
+
78
+ [
79
+ Image.open("./assets/frog/frog.jpeg").convert("RGBA"),
80
+ "add a magic hat on frog head.",
81
+ 642087011,
82
+ "frog",
83
+ "frog",
84
+ True,
85
+ False,
86
+ "GPT4-o (Highly Recommended)"
87
+ ],
88
+ [
89
+ Image.open("./assets/chinese_girl/chinese_girl.png").convert("RGBA"),
90
+ "replace the background to ancient China.",
91
+ 648464818,
92
+ "chinese_girl",
93
+ "chinese_girl",
94
+ True,
95
+ False,
96
+ "GPT4-o (Highly Recommended)"
97
+ ],
98
+ [
99
+ Image.open("./assets/angel_christmas/angel_christmas.png").convert("RGBA"),
100
+ "remove the deer.",
101
+ 648464818,
102
+ "angel_christmas",
103
+ "angel_christmas",
104
+ False,
105
+ False,
106
+ "GPT4-o (Highly Recommended)"
107
+ ],
108
+ [
109
+ Image.open("./assets/sunflower_girl/sunflower_girl.png").convert("RGBA"),
110
+ "add a wreath on head.",
111
+ 648464818,
112
+ "sunflower_girl",
113
+ "sunflower_girl",
114
+ True,
115
+ False,
116
+ "GPT4-o (Highly Recommended)"
117
+ ],
118
+ [
119
+ Image.open("./assets/girl_on_sun/girl_on_sun.png").convert("RGBA"),
120
+ "add a butterfly fairy.",
121
+ 648464818,
122
+ "girl_on_sun",
123
+ "girl_on_sun",
124
+ True,
125
+ False,
126
+ "GPT4-o (Highly Recommended)"
127
+ ],
128
+ [
129
+ Image.open("./assets/spider_man_rm/spider_man.png").convert("RGBA"),
130
+ "remove the christmas hat.",
131
+ 642087011,
132
+ "spider_man_rm",
133
+ "spider_man_rm",
134
+ False,
135
+ False,
136
+ "GPT4-o (Highly Recommended)"
137
+ ],
138
+ [
139
+ Image.open("./assets/anime_flower/anime_flower.png").convert("RGBA"),
140
+ "remove the flower.",
141
+ 642087011,
142
+ "anime_flower",
143
+ "anime_flower",
144
+ False,
145
+ False,
146
+ "GPT4-o (Highly Recommended)"
147
+ ],
148
+ [
149
+ Image.open("./assets/chenduling/chengduling.jpg").convert("RGBA"),
150
+ "replace the clothes to a delicate floral skirt.",
151
+ 648464818,
152
+ "chenduling",
153
+ "chenduling",
154
+ True,
155
+ False,
156
+ "GPT4-o (Highly Recommended)"
157
+ ],
158
+ [
159
+ Image.open("./assets/hedgehog_rp_bg/hedgehog.png").convert("RGBA"),
160
+ "make the hedgehog in Italy.",
161
+ 648464818,
162
+ "hedgehog_rp_bg",
163
+ "hedgehog_rp_bg",
164
+ True,
165
+ False,
166
+ "GPT4-o (Highly Recommended)"
167
+ ],
168
+
169
+ ]
170
+
171
+ INPUT_IMAGE_PATH = {
172
+ "frog": "./assets/frog/frog.jpeg",
173
+ "chinese_girl": "./assets/chinese_girl/chinese_girl.png",
174
+ "angel_christmas": "./assets/angel_christmas/angel_christmas.png",
175
+ "sunflower_girl": "./assets/sunflower_girl/sunflower_girl.png",
176
+ "girl_on_sun": "./assets/girl_on_sun/girl_on_sun.png",
177
+ "spider_man_rm": "./assets/spider_man_rm/spider_man.png",
178
+ "anime_flower": "./assets/anime_flower/anime_flower.png",
179
+ "chenduling": "./assets/chenduling/chengduling.jpg",
180
+ "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/hedgehog.png",
181
+ }
182
+ MASK_IMAGE_PATH = {
183
+ "frog": "./assets/frog/mask_f7b350de-6f2c-49e3-b535-995c486d78e7.png",
184
+ "chinese_girl": "./assets/chinese_girl/mask_54759648-0989-48e0-bc82-f20e28b5ec29.png",
185
+ "angel_christmas": "./assets/angel_christmas/mask_f15d9b45-c978-4e3d-9f5f-251e308560c3.png",
186
+ "sunflower_girl": "./assets/sunflower_girl/mask_99cc50b4-7dc4-4de5-8748-ec10772f0317.png",
187
+ "girl_on_sun": "./assets/girl_on_sun/mask_264eac8b-8b65-479c-9755-020a60880c37.png",
188
+ "spider_man_rm": "./assets/spider_man_rm/mask_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png",
189
+ "anime_flower": "./assets/anime_flower/mask_37553172-9b38-4727-bf2e-37d7e2b93461.png",
190
+ "chenduling": "./assets/chenduling/mask_68e3ff6f-da07-4b37-91df-13d6eed7b997.png",
191
+ "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/mask_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png",
192
+ }
193
+ MASKED_IMAGE_PATH = {
194
+ "frog": "./assets/frog/masked_image_f7b350de-6f2c-49e3-b535-995c486d78e7.png",
195
+ "chinese_girl": "./assets/chinese_girl/masked_image_54759648-0989-48e0-bc82-f20e28b5ec29.png",
196
+ "angel_christmas": "./assets/angel_christmas/masked_image_f15d9b45-c978-4e3d-9f5f-251e308560c3.png",
197
+ "sunflower_girl": "./assets/sunflower_girl/masked_image_99cc50b4-7dc4-4de5-8748-ec10772f0317.png",
198
+ "girl_on_sun": "./assets/girl_on_sun/masked_image_264eac8b-8b65-479c-9755-020a60880c37.png",
199
+ "spider_man_rm": "./assets/spider_man_rm/masked_image_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png",
200
+ "anime_flower": "./assets/anime_flower/masked_image_37553172-9b38-4727-bf2e-37d7e2b93461.png",
201
+ "chenduling": "./assets/chenduling/masked_image_68e3ff6f-da07-4b37-91df-13d6eed7b997.png",
202
+ "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/masked_image_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png",
203
+ }
204
+ OUTPUT_IMAGE_PATH = {
205
+ "frog": "./assets/frog/image_edit_f7b350de-6f2c-49e3-b535-995c486d78e7_1.png",
206
+ "chinese_girl": "./assets/chinese_girl/image_edit_54759648-0989-48e0-bc82-f20e28b5ec29_1.png",
207
+ "angel_christmas": "./assets/angel_christmas/image_edit_f15d9b45-c978-4e3d-9f5f-251e308560c3_0.png",
208
+ "sunflower_girl": "./assets/sunflower_girl/image_edit_99cc50b4-7dc4-4de5-8748-ec10772f0317_3.png",
209
+ "girl_on_sun": "./assets/girl_on_sun/image_edit_264eac8b-8b65-479c-9755-020a60880c37_0.png",
210
+ "spider_man_rm": "./assets/spider_man_rm/image_edit_a5d410e6-8e8d-432f-8144-defbc3e1eae9_0.png",
211
+ "anime_flower": "./assets/anime_flower/image_edit_37553172-9b38-4727-bf2e-37d7e2b93461_2.png",
212
+ "chenduling": "./assets/chenduling/image_edit_68e3ff6f-da07-4b37-91df-13d6eed7b997_0.png",
213
+ "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/image_edit_db7f8bf8-8349-46d3-b14e-43d67fbe25d3_3.png",
214
+ }
215
+
216
+
217
+ # os.environ['GRADIO_TEMP_DIR'] = 'gradio_temp_dir'
218
+ # os.makedirs('gradio_temp_dir', exist_ok=True)
219
+
220
+ VLM_MODEL_NAMES = list(vlms_template.keys())
221
+ DEFAULT_VLM_MODEL_NAME = "Qwen2-VL-7B-Instruct (Default)"
222
+ BASE_MODELS = list(base_models_template.keys())
223
+ DEFAULT_BASE_MODEL = "realisticVision (Default)"
224
+
225
+ ASPECT_RATIO_LABELS = list(aspect_ratios)
226
+ DEFAULT_ASPECT_RATIO = ASPECT_RATIO_LABELS[0]
227
+
228
+ ## init device
229
+ try:
230
+ if torch.cuda.is_available():
231
+ device = "cuda:0"
232
+ elif sys.platform == "darwin" and torch.backends.mps.is_available():
233
+ device = "mps"
234
+ else:
235
+ device = "cpu"
236
+ except Exception:
237
+ device = "cpu"
238
+
239
+ ## init torch dtype
240
+ if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
241
+ torch_dtype = torch.bfloat16
242
+ else:
243
+ torch_dtype = torch.float16
244
+
245
+ if device == "mps":
246
+ torch_dtype = torch.float16
247
+
248
+
249
+
250
+ # download hf models
251
+ BrushEdit_path = "models/"
252
+ if not os.path.exists(BrushEdit_path):
253
+ BrushEdit_path = snapshot_download(
254
+ repo_id="TencentARC/BrushEdit",
255
+ local_dir=BrushEdit_path,
256
+ token=os.getenv("HF_TOKEN"),
257
+ )
258
+
259
+ ## init default VLM
260
+ vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[DEFAULT_VLM_MODEL_NAME]
261
+ if vlm_processor != "" and vlm_model != "":
262
+ vlm_model.to(device)
263
+ else:
264
+ raise gr.Error("Please download the default VLM model "+ DEFAULT_VLM_MODEL_NAME +" first.")
265
+
266
+
267
+ ## init base model
268
+ base_model_path = os.path.join(BrushEdit_path, "base_model/realisticVisionV60B1_v51VAE")
269
+ brushnet_path = os.path.join(BrushEdit_path, "brushnetX")
270
+ sam_path = os.path.join(BrushEdit_path, "sam/sam_vit_h_4b8939.pth")
271
+ groundingdino_path = os.path.join(BrushEdit_path, "grounding_dino/groundingdino_swint_ogc.pth")
272
+
273
+
274
+ # input brushnetX ckpt path
275
+ brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch_dtype)
276
+ pipe = StableDiffusionBrushNetPipeline.from_pretrained(
277
+ base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
278
+ )
279
+ # speed up diffusion process with faster scheduler and memory optimization
280
+ pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
281
+ # remove following line if xformers is not installed or when using Torch 2.0.
282
+ # pipe.enable_xformers_memory_efficient_attention()
283
+ pipe.enable_model_cpu_offload()
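+ # model CPU offload keeps weights on the CPU and moves each sub-module to the GPU only while it runs, trading some speed for a much smaller VRAM footprint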
284
+
285
+
286
+ ## init SAM
287
+ sam = build_sam(checkpoint=sam_path)
288
+ sam.to(device=device)
289
+ sam_predictor = SamPredictor(sam)
290
+ sam_automask_generator = SamAutomaticMaskGenerator(sam)
291
+
292
+ ## init groundingdino_model
293
+ config_file = 'app/utils/GroundingDINO_SwinT_OGC.py'
294
+ groundingdino_model = load_grounding_dino_model(config_file, groundingdino_path, device=device)
295
+
296
+ ## Ordinary function
297
+ def crop_and_resize(image: Image.Image,
298
+ target_width: int,
299
+ target_height: int) -> Image.Image:
300
+ """
301
+ Crops and resizes an image while preserving the aspect ratio.
302
+
303
+ Args:
304
+ image (Image.Image): Input PIL image to be cropped and resized.
305
+ target_width (int): Target width of the output image.
306
+ target_height (int): Target height of the output image.
307
+
308
+ Returns:
309
+ Image.Image: Cropped and resized image.
310
+ """
311
+ # Original dimensions
312
+ original_width, original_height = image.size
313
+ original_aspect = original_width / original_height
314
+ target_aspect = target_width / target_height
315
+
316
+ # Calculate crop box to maintain aspect ratio
317
+ if original_aspect > target_aspect:
318
+ # Crop horizontally
319
+ new_width = int(original_height * target_aspect)
320
+ new_height = original_height
321
+ left = (original_width - new_width) / 2
322
+ top = 0
323
+ right = left + new_width
324
+ bottom = original_height
325
+ else:
326
+ # Crop vertically
327
+ new_width = original_width
328
+ new_height = int(original_width / target_aspect)
329
+ left = 0
330
+ top = (original_height - new_height) / 2
331
+ right = original_width
332
+ bottom = top + new_height
333
+
334
+ # Crop and resize
335
+ cropped_image = image.crop((left, top, right, bottom))
336
+ resized_image = cropped_image.resize((target_width, target_height), Image.NEAREST)
337
+ return resized_image
338
+
339
+
340
+ ## Ordinary function
341
+ def resize(image: Image.Image,
342
+ target_width: int,
343
+ target_height: int) -> Image.Image:
344
+ """
345
+ Resizes an image to the target width and height without cropping (the aspect ratio may change).
346
+
347
+ Args:
348
+ image (Image.Image): Input PIL image to be resized.
349
+ target_width (int): Target width of the output image.
350
+ target_height (int): Target height of the output image.
351
+
352
+ Returns:
353
+ Image.Image: Resized image.
354
+ """
355
+ # resize directly to the requested size (nearest-neighbor)
356
+ resized_image = image.resize((target_width, target_height), Image.NEAREST)
357
+ return resized_image
358
+
359
+
360
+ def move_mask_func(mask, direction, units):
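+ # shift a binary mask by `units` pixels in the given direction; pixels shifted out of frame are dropped and the vacated area becomes background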
361
+ binary_mask = mask.squeeze()>0
362
+ rows, cols = binary_mask.shape
363
+ moved_mask = np.zeros_like(binary_mask, dtype=bool)
364
+
365
+ if direction == 'down':
366
+ # move down
367
+ moved_mask[max(0, units):, :] = binary_mask[:rows - units, :]
368
+
369
+ elif direction == 'up':
370
+ # move up
371
+ moved_mask[:rows - units, :] = binary_mask[units:, :]
372
+
373
+ elif direction == 'right':
374
+ # move right
375
+ moved_mask[:, max(0, units):] = binary_mask[:, :cols - units]
376
+
377
+ elif direction == 'left':
378
+ # move left
379
+ moved_mask[:, :cols - units] = binary_mask[:, units:]
380
+
381
+ return moved_mask
382
+
383
+
384
+ def random_mask_func(mask, dilation_type='square', dilation_size=20):
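+ # derive a coarser mask from an existing one: square dilation/erosion, or replace it with its bounding box / bounding ellipse; returns an (h, w, 1) uint8 mask in {0, 255}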
385
+ # treat any non-zero pixel as foreground
386
+ binary_mask = mask.squeeze()>0
387
+
388
+ if dilation_type == 'square_dilation':
389
+ structure = np.ones((dilation_size, dilation_size), dtype=bool)
390
+ dilated_mask = binary_dilation(binary_mask, structure=structure)
391
+ elif dilation_type == 'square_erosion':
392
+ structure = np.ones((dilation_size, dilation_size), dtype=bool)
393
+ dilated_mask = binary_erosion(binary_mask, structure=structure)
394
+ elif dilation_type == 'bounding_box':
395
+ # find the most left top and left bottom point
396
+ rows, cols = np.where(binary_mask)
397
+ if len(rows) == 0 or len(cols) == 0:
398
+ return mask # return original mask if no valid points
399
+
400
+ min_row = np.min(rows)
401
+ max_row = np.max(rows)
402
+ min_col = np.min(cols)
403
+ max_col = np.max(cols)
404
+
405
+ # create a bounding box
406
+ dilated_mask = np.zeros_like(binary_mask, dtype=bool)
407
+ dilated_mask[min_row:max_row + 1, min_col:max_col + 1] = True
408
+
409
+ elif dilation_type == 'bounding_ellipse':
410
+ # find the most left top and left bottom point
411
+ rows, cols = np.where(binary_mask)
412
+ if len(rows) == 0 or len(cols) == 0:
413
+ return mask # return original mask if no valid points
414
+
415
+ min_row = np.min(rows)
416
+ max_row = np.max(rows)
417
+ min_col = np.min(cols)
418
+ max_col = np.max(cols)
419
+
420
+ # calculate the center and axis length of the ellipse
421
+ center = ((min_col + max_col) // 2, (min_row + max_row) // 2)
422
+ a = (max_col - min_col) // 2 # half long axis
423
+ b = (max_row - min_row) // 2 # half short axis
424
+
425
+ # create a bounding ellipse
426
+ y, x = np.ogrid[:mask.shape[0], :mask.shape[1]]
427
+ ellipse_mask = ((x - center[0])**2 / a**2 + (y - center[1])**2 / b**2) <= 1
428
+ dilated_mask = np.zeros_like(binary_mask, dtype=bool)
429
+ dilated_mask[ellipse_mask] = True
430
+ else:
431
+ raise ValueError("dilation_type must be 'square_dilation', 'square_erosion', 'bounding_box' or 'bounding_ellipse'")
432
+
433
+ # convert the boolean mask to uint8 {0, 255} with a trailing channel dimension
434
+ dilated_mask = np.uint8(dilated_mask[:,:,np.newaxis]) * 255
435
+ return dilated_mask
436
+
437
+
438
+ ## Gradio component function
439
+ def update_vlm_model(vlm_name):
440
+ global vlm_model, vlm_processor
441
+ if vlm_model is not None:
442
+ del vlm_model
443
+ torch.cuda.empty_cache()
444
+
445
+ vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[vlm_name]
446
+
447
+ ## we recommend using preloaded models; otherwise the model is downloaded on first use, which is slow. You can edit the model list via vlm_template.py
448
+ if vlm_type == "llava-next":
449
+ if vlm_processor != "" and vlm_model != "":
450
+ vlm_model.to(device)
451
+ return vlm_model_dropdown
452
+ else:
453
+ if os.path.exists(vlm_local_path):
454
+ vlm_processor = LlavaNextProcessor.from_pretrained(vlm_local_path)
455
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto")
456
+ else:
457
+ if vlm_name == "llava-v1.6-mistral-7b-hf (Preload)":
458
+ vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
459
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", torch_dtype="auto", device_map="auto")
460
+ elif vlm_name == "llama3-llava-next-8b-hf (Preload)":
461
+ vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llama3-llava-next-8b-hf")
462
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llama3-llava-next-8b-hf", torch_dtype="auto", device_map="auto")
463
+ elif vlm_name == "llava-v1.6-vicuna-13b-hf (Preload)":
464
+ vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf")
465
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf", torch_dtype="auto", device_map="auto")
466
+ elif vlm_name == "llava-v1.6-34b-hf (Preload)":
467
+ vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-34b-hf")
468
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-34b-hf", torch_dtype="auto", device_map="auto")
469
+ elif vlm_name == "llava-next-72b-hf (Preload)":
470
+ vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-next-72b-hf")
471
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-next-72b-hf", torch_dtype="auto", device_map="auto")
472
+ elif vlm_type == "qwen2-vl":
473
+ if vlm_processor != "" and vlm_model != "":
474
+ vlm_model.to(device)
475
+ return vlm_model_dropdown
476
+ else:
477
+ if os.path.exists(vlm_local_path):
478
+ vlm_processor = Qwen2VLProcessor.from_pretrained(vlm_local_path)
479
+ vlm_model = Qwen2VLForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto")
480
+ else:
481
+ if vlm_name == "qwen2-vl-2b-instruct (Preload)":
482
+ vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
483
+ vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", torch_dtype="auto", device_map="auto")
484
+ elif vlm_name == "qwen2-vl-7b-instruct (Preload)":
485
+ vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
486
+ vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto")
487
+ elif vlm_name == "qwen2-vl-72b-instruct (Preload)":
488
+ vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-72B-Instruct")
489
+ vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-72B-Instruct", torch_dtype="auto", device_map="auto")
490
+ elif vlm_type == "openai":
491
+ pass
492
+ return "success"
493
+
494
+
495
+ def update_base_model(base_model_name):
496
+ global pipe
497
+ ## we recommend using preloaded models; otherwise the model is downloaded on first use, which is slow. You can edit the model list via base_model_template.py
498
+ if pipe is not None:
499
+ del pipe
500
+ torch.cuda.empty_cache()
501
+ base_model_path, pipe = base_models_template[base_model_name]
502
+ if pipe != "":
503
+ pipe.to(device)
504
+ else:
505
+ if os.path.exists(base_model_path):
506
+ pipe = StableDiffusionBrushNetPipeline.from_pretrained(
507
+ base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
508
+ )
509
+ # pipe.enable_xformers_memory_efficient_attention()
510
+ pipe.enable_model_cpu_offload()
511
+ else:
512
+ raise gr.Error(f"The base model {base_model_name} does not exist")
513
+ return "success"
514
+
515
+
516
+ def submit_GPT4o_KEY(GPT4o_KEY):
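+ # swap the active VLM for the OpenAI client and fire a one-line test request to verify the key before reporting success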
517
+ global vlm_model, vlm_processor
518
+ if vlm_model is not None:
519
+ del vlm_model
520
+ torch.cuda.empty_cache()
521
+ try:
522
+ vlm_model = OpenAI(api_key=GPT4o_KEY)
523
+ vlm_processor = ""
524
+ response = vlm_model.chat.completions.create(
525
+ model="gpt-4o-2024-08-06",
526
+ messages=[
527
+ {"role": "system", "content": "You are a helpful assistant."},
528
+ {"role": "user", "content": "Say this is a test"}
529
+ ]
530
+ )
531
+ response_str = response.choices[0].message.content
532
+
533
+ return "Success, " + response_str, "GPT4-o (Highly Recommended)"
534
+ except Exception as e:
535
+ return "Invalid GPT4o API Key", "GPT4-o (Highly Recommended)"
536
+
537
+
538
+
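+ # Main edit entry point: resolve the mask (user brush, previous mask, or VLM + GroundingDINO + SAM), build the target prompt via the VLM, then run BrushNet inpainting.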
539
+ def process(input_image,
540
+ original_image,
541
+ original_mask,
542
+ prompt,
543
+ negative_prompt,
544
+ control_strength,
545
+ seed,
546
+ randomize_seed,
547
+ guidance_scale,
548
+ num_inference_steps,
549
+ num_samples,
550
+ blending,
551
+ category,
552
+ target_prompt,
553
+ resize_default,
554
+ aspect_ratio_name,
555
+ invert_mask_state):
556
+ if original_image is None:
557
+ if input_image is None:
558
+ raise gr.Error('Please upload the input image')
559
+ else:
560
+ image_pil = input_image["background"].convert("RGB")
561
+ original_image = np.array(image_pil)
562
+ if prompt is None or prompt == "":
563
+ if target_prompt is None or target_prompt == "":
564
+ raise gr.Error("Please input your instructions, e.g., remove the xxx")
565
+
566
+ alpha_mask = input_image["layers"][0].split()[3]
567
+ input_mask = np.asarray(alpha_mask)
568
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
569
+ if output_w == "" or output_h == "":
570
+ output_h, output_w = original_image.shape[:2]
571
+
572
+ if resize_default:
573
+ short_side = min(output_w, output_h)
574
+ scale_ratio = 640 / short_side
575
+ output_w = int(output_w * scale_ratio)
576
+ output_h = int(output_h * scale_ratio)
577
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
578
+ original_image = np.array(original_image)
579
+ if input_mask is not None:
580
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
581
+ input_mask = np.array(input_mask)
582
+ if original_mask is not None:
583
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
584
+ original_mask = np.array(original_mask)
585
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
586
+ else:
587
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
588
+ pass
589
+ else:
590
+ if resize_default:
591
+ short_side = min(output_w, output_h)
592
+ scale_ratio = 640 / short_side
593
+ output_w = int(output_w * scale_ratio)
594
+ output_h = int(output_h * scale_ratio)
595
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
596
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
597
+ original_image = np.array(original_image)
598
+ if input_mask is not None:
599
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
600
+ input_mask = np.array(input_mask)
601
+ if original_mask is not None:
602
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
603
+ original_mask = np.array(original_mask)
604
+
605
+ if invert_mask_state:
606
+ original_mask = original_mask
607
+ else:
608
+ if input_mask.max() == 0:
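+ # no brush strokes were drawn, so keep the previously generated mask; otherwise the hand-drawn mask takes priority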
609
+ original_mask = original_mask
610
+ else:
611
+ original_mask = input_mask
612
+
613
+
614
+ ## inpainting directly if target_prompt is not None
615
+ if category is not None:
616
+ pass
617
+ elif target_prompt is not None and len(target_prompt) >= 1 and original_mask is not None:
618
+ pass
619
+ else:
620
+ try:
621
+ category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device)
622
+ except Exception as e:
623
+ raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
624
+
625
+
626
+ if original_mask is not None:
627
+ original_mask = np.clip(original_mask, 0, 255).astype(np.uint8)
628
+ else:
629
+ try:
630
+ object_wait_for_edit = vlm_response_object_wait_for_edit(
631
+ vlm_processor,
632
+ vlm_model,
633
+ original_image,
634
+ category,
635
+ prompt,
636
+ device)
637
+
638
+ original_mask = vlm_response_mask(vlm_processor,
639
+ vlm_model,
640
+ category,
641
+ original_image,
642
+ prompt,
643
+ object_wait_for_edit,
644
+ sam,
645
+ sam_predictor,
646
+ sam_automask_generator,
647
+ groundingdino_model,
648
+ device).astype(np.uint8)
649
+ except Exception as e:
650
+ raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
651
+
652
+ if original_mask.ndim == 2:
653
+ original_mask = original_mask[:,:,None]
654
+
655
+
656
+ if target_prompt is not None and len(target_prompt) >= 1:
657
+ prompt_after_apply_instruction = target_prompt
658
+
659
+ else:
660
+ try:
661
+ prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(
662
+ vlm_processor,
663
+ vlm_model,
664
+ original_image,
665
+ prompt,
666
+ device)
667
+ except Exception as e:
668
+ raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
669
+
670
+ generator = torch.Generator(device).manual_seed(random.randint(0, 2147483647) if randomize_seed else seed)
671
+
672
+
673
+ with torch.autocast(device):
674
+ image, mask_image, mask_np, init_image_np = BrushEdit_Pipeline(pipe,
675
+ prompt_after_apply_instruction,
676
+ original_mask,
677
+ original_image,
678
+ generator,
679
+ num_inference_steps,
680
+ guidance_scale,
681
+ control_strength,
682
+ negative_prompt,
683
+ num_samples,
684
+ blending)
685
+ original_image = np.array(init_image_np)
686
+ masked_image = original_image * (1 - (mask_np>0))
687
+ masked_image = masked_image.astype(np.uint8)
688
+ masked_image = Image.fromarray(masked_image)
689
+ # Save the images (optional)
690
+ # import uuid
691
+ # uuid = str(uuid.uuid4())
692
+ # image[0].save(f"outputs/image_edit_{uuid}_0.png")
693
+ # image[1].save(f"outputs/image_edit_{uuid}_1.png")
694
+ # image[2].save(f"outputs/image_edit_{uuid}_2.png")
695
+ # image[3].save(f"outputs/image_edit_{uuid}_3.png")
696
+ # mask_image.save(f"outputs/mask_{uuid}.png")
697
+ # masked_image.save(f"outputs/masked_image_{uuid}.png")
698
+ gr.Info(f"Target Prompt: {prompt_after_apply_instruction}", duration=20)
699
+ return image, [mask_image], [masked_image], prompt, '', False
700
+
701
+
702
+ def generate_target_prompt(input_image,
703
+ original_image,
704
+ prompt):
705
+ # load example image
706
+ if isinstance(original_image, str):
707
+ original_image = input_image
708
+
709
+ prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(
710
+ vlm_processor,
711
+ vlm_model,
712
+ original_image,
713
+ prompt,
714
+ device)
715
+ return prompt_after_apply_instruction
716
+
717
+
718
+ def process_mask(input_image,
719
+ original_image,
720
+ prompt,
721
+ resize_default,
722
+ aspect_ratio_name):
723
+ if original_image is None:
724
+ raise gr.Error('Please upload the input image')
725
+ if prompt is None:
726
+ raise gr.Error("Please input your instructions, e.g., remove the xxx")
727
+
728
+ ## load mask
729
+ alpha_mask = input_image["layers"][0].split()[3]
730
+ input_mask = np.array(alpha_mask)
731
+
732
+ # load example image
733
+ if isinstance(original_image, str):
734
+ original_image = input_image["background"]
735
+
736
+ if input_mask.max() == 0:
737
+ category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device)
738
+
739
+ object_wait_for_edit = vlm_response_object_wait_for_edit(vlm_processor,
740
+ vlm_model,
741
+ original_image,
742
+ category,
743
+ prompt,
744
+ device)
745
+ # original mask: h,w,1 [0, 255]
746
+ original_mask = vlm_response_mask(
747
+ vlm_processor,
748
+ vlm_model,
749
+ category,
750
+ original_image,
751
+ prompt,
752
+ object_wait_for_edit,
753
+ sam,
754
+ sam_predictor,
755
+ sam_automask_generator,
756
+ groundingdino_model,
757
+ device).astype(np.uint8)
758
+ else:
759
+ original_mask = input_mask.astype(np.uint8)
760
+ category = None
761
+
762
+ ## resize mask if needed
763
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
764
+ if output_w == "" or output_h == "":
765
+ output_h, output_w = original_image.shape[:2]
766
+ if resize_default:
767
+ short_side = min(output_w, output_h)
768
+ scale_ratio = 640 / short_side
769
+ output_w = int(output_w * scale_ratio)
770
+ output_h = int(output_h * scale_ratio)
771
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
772
+ original_image = np.array(original_image)
773
+ if input_mask is not None:
774
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
775
+ input_mask = np.array(input_mask)
776
+ if original_mask is not None:
777
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
778
+ original_mask = np.array(original_mask)
779
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
780
+ else:
781
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
782
+ pass
783
+ else:
784
+ if resize_default:
785
+ short_side = min(output_w, output_h)
786
+ scale_ratio = 640 / short_side
787
+ output_w = int(output_w * scale_ratio)
788
+ output_h = int(output_h * scale_ratio)
789
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
790
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
791
+ original_image = np.array(original_image)
792
+ if input_mask is not None:
793
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
794
+ input_mask = np.array(input_mask)
795
+ if original_mask is not None:
796
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
797
+ original_mask = np.array(original_mask)
798
+
799
+
800
+ if original_mask.ndim == 2:
801
+ original_mask = original_mask[:,:,None]
802
+
803
+ mask_image = Image.fromarray(original_mask.squeeze().astype(np.uint8)).convert("RGB")
804
+
805
+ masked_image = original_image * (1 - (original_mask>0))
806
+ masked_image = masked_image.astype(np.uint8)
807
+ masked_image = Image.fromarray(masked_image)
808
+
809
+ return [masked_image], [mask_image], original_mask.astype(np.uint8), category
810
+
811
+
812
+ def process_random_mask(input_image,
813
+ original_image,
814
+ original_mask,
815
+ resize_default,
816
+ aspect_ratio_name,
817
+ ):
818
+
819
+ alpha_mask = input_image["layers"][0].split()[3]
820
+ input_mask = np.asarray(alpha_mask)
821
+
822
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
823
+ if output_w == "" or output_h == "":
824
+ output_h, output_w = original_image.shape[:2]
825
+ if resize_default:
826
+ short_side = min(output_w, output_h)
827
+ scale_ratio = 640 / short_side
828
+ output_w = int(output_w * scale_ratio)
829
+ output_h = int(output_h * scale_ratio)
830
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
831
+ original_image = np.array(original_image)
832
+ if input_mask is not None:
833
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
834
+ input_mask = np.array(input_mask)
835
+ if original_mask is not None:
836
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
837
+ original_mask = np.array(original_mask)
838
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
839
+ else:
840
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
841
+ pass
842
+ else:
843
+ if resize_default:
844
+ short_side = min(output_w, output_h)
845
+ scale_ratio = 640 / short_side
846
+ output_w = int(output_w * scale_ratio)
847
+ output_h = int(output_h * scale_ratio)
848
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
849
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
850
+ original_image = np.array(original_image)
851
+ if input_mask is not None:
852
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
853
+ input_mask = np.array(input_mask)
854
+ if original_mask is not None:
855
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
856
+ original_mask = np.array(original_mask)
857
+
858
+
859
+ if input_mask.max() == 0:
860
+ original_mask = original_mask
861
+ else:
862
+ original_mask = input_mask
863
+
864
+ if original_mask is None:
865
+ raise gr.Error('Please generate mask first')
866
+
867
+ if original_mask.ndim == 2:
868
+ original_mask = original_mask[:,:,None]
869
+
870
+ dilation_type = np.random.choice(['bounding_box', 'bounding_ellipse'])
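+ # the "Square/Circle Mask" button replaces the current mask with a random choice of its bounding box or bounding ellipse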
871
+ random_mask = random_mask_func(original_mask, dilation_type).squeeze()
872
+
873
+ mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
874
+
875
+ masked_image = original_image * (1 - (random_mask[:,:,None]>0))
876
+ masked_image = masked_image.astype(original_image.dtype)
877
+ masked_image = Image.fromarray(masked_image)
878
+
879
+
880
+ return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
881
+
882
+
883
+ def process_dilation_mask(input_image,
884
+ original_image,
885
+ original_mask,
886
+ resize_default,
887
+ aspect_ratio_name,
888
+ dilation_size=20):
889
+
890
+ alpha_mask = input_image["layers"][0].split()[3]
891
+ input_mask = np.asarray(alpha_mask)
892
+
893
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
894
+ if output_w == "" or output_h == "":
895
+ output_h, output_w = original_image.shape[:2]
896
+ if resize_default:
897
+ short_side = min(output_w, output_h)
898
+ scale_ratio = 640 / short_side
899
+ output_w = int(output_w * scale_ratio)
900
+ output_h = int(output_h * scale_ratio)
901
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
902
+ original_image = np.array(original_image)
903
+ if input_mask is not None:
904
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
905
+ input_mask = np.array(input_mask)
906
+ if original_mask is not None:
907
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
908
+ original_mask = np.array(original_mask)
909
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
910
+ else:
911
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
912
+ pass
913
+ else:
914
+ if resize_default:
915
+ short_side = min(output_w, output_h)
916
+ scale_ratio = 640 / short_side
917
+ output_w = int(output_w * scale_ratio)
918
+ output_h = int(output_h * scale_ratio)
919
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
920
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
921
+ original_image = np.array(original_image)
922
+ if input_mask is not None:
923
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
924
+ input_mask = np.array(input_mask)
925
+ if original_mask is not None:
926
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
927
+ original_mask = np.array(original_mask)
928
+
929
+ if input_mask.max() == 0:
930
+ original_mask = original_mask
931
+ else:
932
+ original_mask = input_mask
933
+
934
+ if original_mask is None:
935
+ raise gr.Error('Please generate mask first')
936
+
937
+ if original_mask.ndim == 2:
938
+ original_mask = original_mask[:,:,None]
939
+
940
+ dilation_type = np.random.choice(['square_dilation'])
941
+ random_mask = random_mask_func(original_mask, dilation_type, dilation_size).squeeze()
942
+
943
+ mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
944
+
945
+ masked_image = original_image * (1 - (random_mask[:,:,None]>0))
946
+ masked_image = masked_image.astype(original_image.dtype)
947
+ masked_image = Image.fromarray(masked_image)
948
+
949
+ return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
950
+
951
+
952
+ def process_erosion_mask(input_image,
953
+ original_image,
954
+ original_mask,
955
+ resize_default,
956
+ aspect_ratio_name,
957
+ dilation_size=20):
958
+ alpha_mask = input_image["layers"][0].split()[3]
959
+ input_mask = np.asarray(alpha_mask)
960
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
961
+ if output_w == "" or output_h == "":
962
+ output_h, output_w = original_image.shape[:2]
963
+ if resize_default:
964
+ short_side = min(output_w, output_h)
965
+ scale_ratio = 640 / short_side
966
+ output_w = int(output_w * scale_ratio)
967
+ output_h = int(output_h * scale_ratio)
968
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
969
+ original_image = np.array(original_image)
970
+ if input_mask is not None:
971
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
972
+ input_mask = np.array(input_mask)
973
+ if original_mask is not None:
974
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
975
+ original_mask = np.array(original_mask)
976
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
977
+ else:
978
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
979
+ pass
980
+ else:
981
+ if resize_default:
982
+ short_side = min(output_w, output_h)
983
+ scale_ratio = 640 / short_side
984
+ output_w = int(output_w * scale_ratio)
985
+ output_h = int(output_h * scale_ratio)
986
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
987
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
988
+ original_image = np.array(original_image)
989
+ if input_mask is not None:
990
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
991
+ input_mask = np.array(input_mask)
992
+ if original_mask is not None:
993
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
994
+ original_mask = np.array(original_mask)
995
+
996
+ if input_mask.max() == 0:
997
+ original_mask = original_mask
998
+ else:
999
+ original_mask = input_mask
1000
+
1001
+ if original_mask is None:
1002
+ raise gr.Error('Please generate mask first')
1003
+
1004
+ if original_mask.ndim == 2:
1005
+ original_mask = original_mask[:,:,None]
1006
+
1007
+ dilation_type = np.random.choice(['square_erosion'])
1008
+ random_mask = random_mask_func(original_mask, dilation_type, dilation_size).squeeze()
1009
+
1010
+ mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
1011
+
1012
+ masked_image = original_image * (1 - (random_mask[:,:,None]>0))
1013
+ masked_image = masked_image.astype(original_image.dtype)
1014
+ masked_image = Image.fromarray(masked_image)
1015
+
1016
+
1017
+ return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
1018
+
1019
+
1020
+ def move_mask_left(input_image,
1021
+ original_image,
1022
+ original_mask,
1023
+ moving_pixels,
1024
+ resize_default,
1025
+ aspect_ratio_name):
1026
+
1027
+ alpha_mask = input_image["layers"][0].split()[3]
1028
+ input_mask = np.asarray(alpha_mask)
1029
+
1030
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
1031
+ if output_w == "" or output_h == "":
1032
+ output_h, output_w = original_image.shape[:2]
1033
+ if resize_default:
1034
+ short_side = min(output_w, output_h)
1035
+ scale_ratio = 640 / short_side
1036
+ output_w = int(output_w * scale_ratio)
1037
+ output_h = int(output_h * scale_ratio)
1038
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1039
+ original_image = np.array(original_image)
1040
+ if input_mask is not None:
1041
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1042
+ input_mask = np.array(input_mask)
1043
+ if original_mask is not None:
1044
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1045
+ original_mask = np.array(original_mask)
1046
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1047
+ else:
1048
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1049
+ pass
1050
+ else:
1051
+ if resize_default:
1052
+ short_side = min(output_w, output_h)
1053
+ scale_ratio = 640 / short_side
1054
+ output_w = int(output_w * scale_ratio)
1055
+ output_h = int(output_h * scale_ratio)
1056
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1057
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1058
+ original_image = np.array(original_image)
1059
+ if input_mask is not None:
1060
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1061
+ input_mask = np.array(input_mask)
1062
+ if original_mask is not None:
1063
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1064
+ original_mask = np.array(original_mask)
1065
+
1066
+ if input_mask.max() == 0:
1067
+ original_mask = original_mask
1068
+ else:
1069
+ original_mask = input_mask
1070
+
1071
+ if original_mask is None:
1072
+ raise gr.Error('Please generate mask first')
1073
+
1074
+ if original_mask.ndim == 2:
1075
+ original_mask = original_mask[:,:,None]
1076
+
1077
+ moved_mask = move_mask_func(original_mask, 'left', int(moving_pixels)).squeeze()
1078
+ mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
1079
+
1080
+ masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
1081
+ masked_image = masked_image.astype(original_image.dtype)
1082
+ masked_image = Image.fromarray(masked_image)
1083
+
1084
+ if moved_mask.max() <= 1:
1085
+ moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
1086
+ original_mask = moved_mask
1087
+ return [masked_image], [mask_image], original_mask.astype(np.uint8)
1088
+
1089
+
1090
+ def move_mask_right(input_image,
1091
+ original_image,
1092
+ original_mask,
1093
+ moving_pixels,
1094
+ resize_default,
1095
+ aspect_ratio_name):
1096
+ alpha_mask = input_image["layers"][0].split()[3]
1097
+ input_mask = np.asarray(alpha_mask)
1098
+
1099
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
1100
+ if output_w == "" or output_h == "":
1101
+ output_h, output_w = original_image.shape[:2]
1102
+ if resize_default:
1103
+ short_side = min(output_w, output_h)
1104
+ scale_ratio = 640 / short_side
1105
+ output_w = int(output_w * scale_ratio)
1106
+ output_h = int(output_h * scale_ratio)
1107
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1108
+ original_image = np.array(original_image)
1109
+ if input_mask is not None:
1110
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1111
+ input_mask = np.array(input_mask)
1112
+ if original_mask is not None:
1113
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1114
+ original_mask = np.array(original_mask)
1115
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1116
+ else:
1117
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1118
+ pass
1119
+ else:
1120
+ if resize_default:
1121
+ short_side = min(output_w, output_h)
1122
+ scale_ratio = 640 / short_side
1123
+ output_w = int(output_w * scale_ratio)
1124
+ output_h = int(output_h * scale_ratio)
1125
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1126
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1127
+ original_image = np.array(original_image)
1128
+ if input_mask is not None:
1129
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1130
+ input_mask = np.array(input_mask)
1131
+ if original_mask is not None:
1132
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1133
+ original_mask = np.array(original_mask)
1134
+
1135
+ if input_mask.max() == 0:
1136
+ original_mask = original_mask
1137
+ else:
1138
+ original_mask = input_mask
1139
+
1140
+ if original_mask is None:
1141
+ raise gr.Error('Please generate mask first')
1142
+
1143
+ if original_mask.ndim == 2:
1144
+ original_mask = original_mask[:,:,None]
1145
+
1146
+ moved_mask = move_mask_func(original_mask, 'right', int(moving_pixels)).squeeze()
1147
+
1148
+ mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
1149
+
1150
+ masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
1151
+ masked_image = masked_image.astype(original_image.dtype)
1152
+ masked_image = Image.fromarray(masked_image)
1153
+
1154
+
1155
+ if moved_mask.max() <= 1:
1156
+ moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
1157
+ original_mask = moved_mask
1158
+
1159
+ return [masked_image], [mask_image], original_mask.astype(np.uint8)
1160
+
1161
+
1162
+ def move_mask_up(input_image,
1163
+ original_image,
1164
+ original_mask,
1165
+ moving_pixels,
1166
+ resize_default,
1167
+ aspect_ratio_name):
1168
+ alpha_mask = input_image["layers"][0].split()[3]
1169
+ input_mask = np.asarray(alpha_mask)
1170
+
1171
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
1172
+ if output_w == "" or output_h == "":
1173
+ output_h, output_w = original_image.shape[:2]
1174
+ if resize_default:
1175
+ short_side = min(output_w, output_h)
1176
+ scale_ratio = 640 / short_side
1177
+ output_w = int(output_w * scale_ratio)
1178
+ output_h = int(output_h * scale_ratio)
1179
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1180
+ original_image = np.array(original_image)
1181
+ if input_mask is not None:
1182
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1183
+ input_mask = np.array(input_mask)
1184
+ if original_mask is not None:
1185
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1186
+ original_mask = np.array(original_mask)
1187
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1188
+ else:
1189
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1190
+ pass
1191
+ else:
1192
+ if resize_default:
1193
+ short_side = min(output_w, output_h)
1194
+ scale_ratio = 640 / short_side
1195
+ output_w = int(output_w * scale_ratio)
1196
+ output_h = int(output_h * scale_ratio)
1197
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1198
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1199
+ original_image = np.array(original_image)
1200
+ if input_mask is not None:
1201
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1202
+ input_mask = np.array(input_mask)
1203
+ if original_mask is not None:
1204
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1205
+ original_mask = np.array(original_mask)
1206
+
1207
+ if input_mask.max() == 0:
1208
+ original_mask = original_mask
1209
+ else:
1210
+ original_mask = input_mask
1211
+
1212
+ if original_mask is None:
1213
+ raise gr.Error('Please generate mask first')
1214
+
1215
+ if original_mask.ndim == 2:
1216
+ original_mask = original_mask[:,:,None]
1217
+
1218
+ moved_mask = move_mask_func(original_mask, 'up', int(moving_pixels)).squeeze()
1219
+ mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
1220
+
1221
+ masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
1222
+ masked_image = masked_image.astype(original_image.dtype)
1223
+ masked_image = Image.fromarray(masked_image)
1224
+
1225
+ if moved_mask.max() <= 1:
1226
+ moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
1227
+ original_mask = moved_mask
1228
+
1229
+ return [masked_image], [mask_image], original_mask.astype(np.uint8)
1230
+
1231
+
1232
+ def move_mask_down(input_image,
1233
+ original_image,
1234
+ original_mask,
1235
+ moving_pixels,
1236
+ resize_default,
1237
+ aspect_ratio_name):
1238
+ alpha_mask = input_image["layers"][0].split()[3]
1239
+ input_mask = np.asarray(alpha_mask)
1240
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
1241
+ if output_w == "" or output_h == "":
1242
+ output_h, output_w = original_image.shape[:2]
1243
+ if resize_default:
1244
+ short_side = min(output_w, output_h)
1245
+ scale_ratio = 640 / short_side
1246
+ output_w = int(output_w * scale_ratio)
1247
+ output_h = int(output_h * scale_ratio)
1248
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1249
+ original_image = np.array(original_image)
1250
+ if input_mask is not None:
1251
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1252
+ input_mask = np.array(input_mask)
1253
+ if original_mask is not None:
1254
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1255
+ original_mask = np.array(original_mask)
1256
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1257
+ else:
1258
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1259
+ pass
1260
+ else:
1261
+ if resize_default:
1262
+ short_side = min(output_w, output_h)
1263
+ scale_ratio = 640 / short_side
1264
+ output_w = int(output_w * scale_ratio)
1265
+ output_h = int(output_h * scale_ratio)
1266
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1267
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1268
+ original_image = np.array(original_image)
1269
+ if input_mask is not None:
1270
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1271
+ input_mask = np.array(input_mask)
1272
+ if original_mask is not None:
1273
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1274
+ original_mask = np.array(original_mask)
1275
+
1276
+ if input_mask.max() == 0:
1277
+ original_mask = original_mask
1278
+ else:
1279
+ original_mask = input_mask
1280
+
1281
+ if original_mask is None:
1282
+ raise gr.Error('Please generate mask first')
1283
+
1284
+ if original_mask.ndim == 2:
1285
+ original_mask = original_mask[:,:,None]
1286
+
1287
+ moved_mask = move_mask_func(original_mask, 'down', int(moving_pixels)).squeeze()
1288
+ mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
1289
+
1290
+ masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
1291
+ masked_image = masked_image.astype(original_image.dtype)
1292
+ masked_image = Image.fromarray(masked_image)
1293
+
1294
+ if moved_mask.max() <= 1:
1295
+ moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
1296
+ original_mask = moved_mask
1297
+
1298
+ return [masked_image], [mask_image], original_mask.astype(np.uint8)
1299
+
1300
+
1301
+ def invert_mask(input_image,
1302
+ original_image,
1303
+ original_mask,
1304
+ ):
1305
+ alpha_mask = input_image["layers"][0].split()[3]
1306
+ input_mask = np.asarray(alpha_mask)
1307
+ if input_mask.max() == 0:
1308
+ original_mask = 1 - (original_mask>0).astype(np.uint8)
1309
+ else:
1310
+ original_mask = 1 - (input_mask>0).astype(np.uint8)
1311
+
1312
+ if original_mask is None:
1313
+ raise gr.Error('Please generate mask first')
1314
+
1315
+ original_mask = original_mask.squeeze()
1316
+ mask_image = Image.fromarray(original_mask*255).convert("RGB")
1317
+
1318
+ if original_mask.ndim == 2:
1319
+ original_mask = original_mask[:,:,None]
1320
+
1321
+ if original_mask.max() <= 1:
1322
+ original_mask = (original_mask * 255).astype(np.uint8)
1323
+
1324
+ masked_image = original_image * (1 - (original_mask>0))
1325
+ masked_image = masked_image.astype(original_image.dtype)
1326
+ masked_image = Image.fromarray(masked_image)
1327
+
1328
+ return [masked_image], [mask_image], original_mask, True
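# Illustrative sketch, not part of this commit: the inversion arithmetic invert_mask applies.
# Wherever the current mask is > 0 the inverted mask becomes 0, and vice versa, before being
# rescaled to 0/255 and used to recompute the masked preview.
import numpy as np

current = np.array([[0, 255], [255, 0]], dtype=np.uint8)
inverted = (1 - (current > 0).astype(np.uint8)) * 255
assert inverted.tolist() == [[255, 0], [0, 255]]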
1329
+
1330
+
1331
+ def init_img(base,
1332
+ init_type,
1333
+ prompt,
1334
+ aspect_ratio,
1335
+ example_change_times
1336
+ ):
1337
+ image_pil = base["background"].convert("RGB")
1338
+ original_image = np.array(image_pil)
1339
+ if max(original_image.shape[0], original_image.shape[1]) * 1.0 / min(original_image.shape[0], original_image.shape[1])>2.0:
1340
+ raise gr.Error('image aspect ratio cannot be larger than 2.0')
1341
+ if init_type in MASK_IMAGE_PATH.keys() and example_change_times < 2:
1342
+ mask_gallery = [Image.open(MASK_IMAGE_PATH[init_type]).convert("L")]
1343
+ masked_gallery = [Image.open(MASKED_IMAGE_PATH[init_type]).convert("RGB")]
1344
+ result_gallery = [Image.open(OUTPUT_IMAGE_PATH[init_type]).convert("RGB")]
1345
+ width, height = image_pil.size
1346
+ image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
1347
+ height_new, width_new = image_processor.get_default_height_width(image_pil, height, width)
1348
+ image_pil = image_pil.resize((width_new, height_new))
1349
+ mask_gallery[0] = mask_gallery[0].resize((width_new, height_new))
1350
+ masked_gallery[0] = masked_gallery[0].resize((width_new, height_new))
1351
+ result_gallery[0] = result_gallery[0].resize((width_new, height_new))
1352
+ original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1
1353
+ return base, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, "", "", "Custom resolution", False, False, example_change_times
1354
+ else:
1355
+ if aspect_ratio not in ASPECT_RATIO_LABELS:
1356
+ aspect_ratio = "Custom resolution"
1357
+ return base, original_image, None, "", None, None, None, "", "", aspect_ratio, True, False, 0
1358
+
1359
+
1360
+ def reset_func(input_image,
1361
+ original_image,
1362
+ original_mask,
1363
+ prompt,
1364
+ target_prompt,
1365
+ ):
1366
+ input_image = None
1367
+ original_image = None
1368
+ original_mask = None
1369
+ prompt = ''
1370
+ mask_gallery = []
1371
+ masked_gallery = []
1372
+ result_gallery = []
1373
+ target_prompt = ''
1374
+ if torch.cuda.is_available():
1375
+ torch.cuda.empty_cache()
1376
+ return input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, True, False
1377
+
1378
+
1379
+ def update_example(example_type,
1380
+ prompt,
1381
+ example_change_times):
1382
+ input_image = INPUT_IMAGE_PATH[example_type]
1383
+ image_pil = Image.open(input_image).convert("RGB")
1384
+ mask_gallery = [Image.open(MASK_IMAGE_PATH[example_type]).convert("L")]
1385
+ masked_gallery = [Image.open(MASKED_IMAGE_PATH[example_type]).convert("RGB")]
1386
+ result_gallery = [Image.open(OUTPUT_IMAGE_PATH[example_type]).convert("RGB")]
1387
+ width, height = image_pil.size
1388
+ image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
1389
+ height_new, width_new = image_processor.get_default_height_width(image_pil, height, width)
1390
+ image_pil = image_pil.resize((width_new, height_new))
1391
+ mask_gallery[0] = mask_gallery[0].resize((width_new, height_new))
1392
+ masked_gallery[0] = masked_gallery[0].resize((width_new, height_new))
1393
+ result_gallery[0] = result_gallery[0].resize((width_new, height_new))
1394
+
1395
+ original_image = np.array(image_pil)
1396
+ original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1
1397
+ aspect_ratio = "Custom resolution"
1398
+ example_change_times += 1
1399
+ return input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, "", False, example_change_times
1400
+
1401
+
1402
+ block = gr.Blocks(
1403
+ theme=gr.themes.Soft(
1404
+ radius_size=gr.themes.sizes.radius_none,
1405
+ text_size=gr.themes.sizes.text_md
1406
+ )
1407
+ )
1408
+ with block as demo:
1409
+ with gr.Row():
1410
+ with gr.Column():
1411
+ gr.HTML(head)
1412
+
1413
+ gr.Markdown(descriptions)
1414
+
1415
+ with gr.Accordion(label="🧭 Instructions:", open=True, elem_id="accordion"):
1416
+ with gr.Row(equal_height=True):
1417
+ gr.Markdown(instructions)
1418
+
1419
+ original_image = gr.State(value=None)
1420
+ original_mask = gr.State(value=None)
1421
+ category = gr.State(value=None)
1422
+ status = gr.State(value=None)
1423
+ invert_mask_state = gr.State(value=False)
1424
+ example_change_times = gr.State(value=0)
1425
+
1426
+
1427
+ with gr.Row():
1428
+ with gr.Column():
1429
+ with gr.Row():
1430
+ input_image = gr.ImageEditor(
1431
+ label="Input Image",
1432
+ type="pil",
1433
+ brush=gr.Brush(colors=["#FFFFFF"], default_size = 30, color_mode="fixed"),
1434
+ layers = False,
1435
+ interactive=True,
1436
+ height=1024,
1437
+ sources=["upload"],
1438
+ placeholder="Please click here or the icon below to upload the image.",
1439
+ )
1440
+
1441
+ prompt = gr.Textbox(label="⌨️ Instruction", placeholder="Please input your instruction.", value="",lines=1)
1442
+ run_button = gr.Button("💫 Run")
1443
+
1444
+ vlm_model_dropdown = gr.Dropdown(label="VLM model", choices=VLM_MODEL_NAMES, value=DEFAULT_VLM_MODEL_NAME, interactive=True)
1445
+ with gr.Group():
1446
+ with gr.Row():
1447
+ GPT4o_KEY = gr.Textbox(label="GPT4o API Key", placeholder="Please input your GPT4o API Key when using the GPT4o VLM (highly recommended).", value="", lines=1)
1448
+
1449
+ GPT4o_KEY_submit = gr.Button("Submit and Verify")
1450
+
1451
+
1452
+ aspect_ratio = gr.Dropdown(label="Output aspect ratio", choices=ASPECT_RATIO_LABELS, value=DEFAULT_ASPECT_RATIO)
1453
+ resize_default = gr.Checkbox(label="Short edge resize to 640px", value=True)
1454
+
1455
+ with gr.Row():
1456
+ mask_button = gr.Button("Generate Mask")
1457
+ random_mask_button = gr.Button("Square/Circle Mask")
1458
+
1459
+
1460
+ with gr.Row():
1461
+ generate_target_prompt_button = gr.Button("Generate Target Prompt")
1462
+
1463
+ target_prompt = gr.Text(
1464
+ label="Input Target Prompt",
1465
+ max_lines=5,
1466
+ placeholder="VLM-generated target prompt, you can first generate it and then modify it (optional)",
1467
+ value='',
1468
+ lines=2
1469
+ )
1470
+
1471
+ with gr.Accordion("Advanced Options", open=False, elem_id="accordion1"):
1472
+ base_model_dropdown = gr.Dropdown(label="Base model", choices=BASE_MODELS, value=DEFAULT_BASE_MODEL, interactive=True)
1473
+ negative_prompt = gr.Text(
1474
+ label="Negative Prompt",
1475
+ max_lines=5,
1476
+ placeholder="Please input your negative prompt",
1477
+ value='ugly, low quality',lines=1
1478
+ )
1479
+
1480
+ control_strength = gr.Slider(
1481
+ label="Control Strength: ", show_label=True, minimum=0, maximum=1.1, value=1, step=0.01
1482
+ )
1483
+ with gr.Group():
1484
+ seed = gr.Slider(
1485
+ label="Seed: ", minimum=0, maximum=2147483647, step=1, value=648464818
1486
+ )
1487
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
1488
+
1489
+ blending = gr.Checkbox(label="Blending mode", value=True)
1490
+
1491
+
1492
+ num_samples = gr.Slider(
1493
+ label="Num samples", minimum=0, maximum=4, step=1, value=4
1494
+ )
1495
+
1496
+ with gr.Group():
1497
+ with gr.Row():
1498
+ guidance_scale = gr.Slider(
1499
+ label="Guidance scale",
1500
+ minimum=1,
1501
+ maximum=12,
1502
+ step=0.1,
1503
+ value=7.5,
1504
+ )
1505
+ num_inference_steps = gr.Slider(
1506
+ label="Number of inference steps",
1507
+ minimum=1,
1508
+ maximum=50,
1509
+ step=1,
1510
+ value=50,
1511
+ )
1512
+
1513
+
1514
+ with gr.Column():
1515
+ with gr.Row():
1516
+ with gr.Tab(elem_classes="feedback", label="Masked Image"):
1517
+ masked_gallery = gr.Gallery(label='Masked Image', show_label=True, elem_id="gallery", preview=True, height=360)
1518
+ with gr.Tab(elem_classes="feedback", label="Mask"):
1519
+ mask_gallery = gr.Gallery(label='Mask', show_label=True, elem_id="gallery", preview=True, height=360)
1520
+
1521
+ invert_mask_button = gr.Button("Invert Mask")
1522
+ dilation_size = gr.Slider(
1523
+ label="Dilation size: ", minimum=0, maximum=50, step=1, value=20
1524
+ )
1525
+ with gr.Row():
1526
+ dilation_mask_button = gr.Button("Dilate Generated Mask")
1527
+ erosion_mask_button = gr.Button("Erode Generated Mask")
1528
+
1529
+ moving_pixels = gr.Slider(
1530
+ label="Moving pixels:", show_label=True, minimum=0, maximum=50, value=4, step=1
1531
+ )
1532
+ with gr.Row():
1533
+ move_left_button = gr.Button("Move Left")
1534
+ move_right_button = gr.Button("Move Right")
1535
+ with gr.Row():
1536
+ move_up_button = gr.Button("Move Up")
1537
+ move_down_button = gr.Button("Move Down")
1538
+
1539
+ with gr.Tab(elem_classes="feedback", label="Output"):
1540
+ result_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", preview=True, height=400)
1541
+
1542
+ # target_prompt_output = gr.Text(label="Output Target Prompt", value="", lines=1, interactive=False)
1543
+
1544
+ reset_button = gr.Button("Reset")
1545
+
1546
+ init_type = gr.Textbox(label="Init Name", value="", visible=False)
1547
+ example_type = gr.Textbox(label="Example Name", value="", visible=False)
1548
+
1549
+
1550
+
1551
+ with gr.Row():
1552
+ example = gr.Examples(
1553
+ label="Quick Example",
1554
+ examples=EXAMPLES,
1555
+ inputs=[input_image, prompt, seed, init_type, example_type, blending, resize_default, vlm_model_dropdown],
1556
+ examples_per_page=10,
1557
+ cache_examples=False,
1558
+ )
1559
+
1560
+
1561
+ with gr.Accordion(label="🎬 Feature Details:", open=True, elem_id="accordion"):
1562
+ with gr.Row(equal_height=True):
1563
+ gr.Markdown(tips)
1564
+
1565
+ with gr.Row():
1566
+ gr.Markdown(citation)
1567
+
1568
+ ## gr.Examples cannot be used to update gr.Gallery components, so the two functions below update the galleries instead.
1569
+ ## They also resolve the conflict between the image-upload callback and the example-change callback.
1570
+ input_image.upload(
1571
+ init_img,
1572
+ [input_image, init_type, prompt, aspect_ratio, example_change_times],
1573
+ [input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, init_type, aspect_ratio, resize_default, invert_mask_state, example_change_times]
1574
+ )
1575
+ example_type.change(fn=update_example, inputs=[example_type, prompt, example_change_times], outputs=[input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, target_prompt, invert_mask_state, example_change_times])
1576
+
1577
+ ## vlm and base model dropdown
1578
+ vlm_model_dropdown.change(fn=update_vlm_model, inputs=[vlm_model_dropdown], outputs=[status])
1579
+ base_model_dropdown.change(fn=update_base_model, inputs=[base_model_dropdown], outputs=[status])
1580
+
1581
+
1582
+ GPT4o_KEY_submit.click(fn=submit_GPT4o_KEY, inputs=[GPT4o_KEY], outputs=[GPT4o_KEY, vlm_model_dropdown])
1583
+ invert_mask_button.click(fn=invert_mask, inputs=[input_image, original_image, original_mask], outputs=[masked_gallery, mask_gallery, original_mask, invert_mask_state])
1584
+
1585
+
1586
+ ips=[input_image,
1587
+ original_image,
1588
+ original_mask,
1589
+ prompt,
1590
+ negative_prompt,
1591
+ control_strength,
1592
+ seed,
1593
+ randomize_seed,
1594
+ guidance_scale,
1595
+ num_inference_steps,
1596
+ num_samples,
1597
+ blending,
1598
+ category,
1599
+ target_prompt,
1600
+ resize_default,
1601
+ aspect_ratio,
1602
+ invert_mask_state]
1603
+
1604
+ ## run brushedit
1605
+ run_button.click(fn=process, inputs=ips, outputs=[result_gallery, mask_gallery, masked_gallery, prompt, target_prompt, invert_mask_state])
1606
+
1607
+ ## mask func
1608
+ mask_button.click(fn=process_mask, inputs=[input_image, original_image, prompt, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask, category])
1609
+ random_mask_button.click(fn=process_random_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
1610
+ dilation_mask_button.click(fn=process_dilation_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[ masked_gallery, mask_gallery, original_mask])
1611
+ erosion_mask_button.click(fn=process_erosion_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[ masked_gallery, mask_gallery, original_mask])
1612
+
1613
+ ## move mask func
1614
+ move_left_button.click(fn=move_mask_left, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
1615
+ move_right_button.click(fn=move_mask_right, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
1616
+ move_up_button.click(fn=move_mask_up, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
1617
+ move_down_button.click(fn=move_mask_down, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
1618
+
1619
+ ## prompt func
1620
+ generate_target_prompt_button.click(fn=generate_target_prompt, inputs=[input_image, original_image, prompt], outputs=[target_prompt])
1621
+
1622
+ ## reset func
1623
+ reset_button.click(fn=reset_func, inputs=[input_image, original_image, original_mask, prompt, target_prompt], outputs=[input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, resize_default, invert_mask_state])
1624
+
1625
+ ## if you hit a localhost access error, try the launch arguments below
1626
+ demo.launch(server_name="0.0.0.0", server_port=12345, share=True)
1627
+ # demo.launch()
brushedit_app_gradio_new.py ADDED
The diff for this file is too large to render. See raw diff
 
brushedit_app_new.py ADDED
The diff for this file is too large to render. See raw diff
 
brushedit_app_new_0404_cirr_blip1.py ADDED
@@ -0,0 +1,2058 @@
1
+ ##!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import os, random, sys
4
+ import numpy as np
5
+ import requests
6
+ import torch
7
+
8
+ from pathlib import Path
9
+ import pandas as pd
10
+ import concurrent.futures
11
+ import faiss
12
+ import gradio as gr
13
+
14
+ from pathlib import Path
15
+ import os
16
+ import json
17
+
18
+ from PIL import Image
19
+
20
+ import torch.nn.functional as F # newly added line
21
+ from huggingface_hub import hf_hub_download, snapshot_download
22
+ from scipy.ndimage import binary_dilation, binary_erosion
23
+ from transformers import (LlavaNextProcessor, LlavaNextForConditionalGeneration,
24
+ Qwen2VLForConditionalGeneration, Qwen2VLProcessor)
25
+
26
+ from segment_anything import SamPredictor, build_sam, SamAutomaticMaskGenerator
27
+ from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler
28
+ from diffusers.image_processor import VaeImageProcessor
29
+
30
+
31
+ from app.src.vlm_pipeline import (
32
+ vlm_response_editing_type,
33
+ vlm_response_object_wait_for_edit,
34
+ vlm_response_mask,
35
+ vlm_response_prompt_after_apply_instruction
36
+ )
37
+ from app.src.brushedit_all_in_one_pipeline import BrushEdit_Pipeline
38
+ from app.utils.utils import load_grounding_dino_model
39
+
40
+ from app.src.vlm_template import vlms_template
41
+ from app.src.base_model_template import base_models_template
42
+ from app.src.aspect_ratio_template import aspect_ratios
43
+
44
+ from openai import OpenAI
45
+ base_openai_url = "https://api.deepseek.com/"
46
+ base_api_key = "sk-d145b963a92649a88843caeb741e8bbc"
47
+
48
+
49
+ from transformers import BlipProcessor, BlipForConditionalGeneration
50
+ from transformers import CLIPProcessor, CLIPModel
51
+
52
+ from app.deepseek.instructions import (
53
+ create_apply_editing_messages_deepseek,
54
+ create_decomposed_query_messages_deepseek
55
+ )
56
+ from clip_retrieval.clip_client import ClipClient
57
+
58
+ #### Description ####
59
+ logo = r"""
60
+ <center><img src='./assets/logo_brushedit.png' alt='BrushEdit logo' style="width:80px; margin-bottom:10px"></center>
61
+ """
62
+ head = r"""
63
+ <div style="text-align: center;">
64
+ <h1>Zero-Shot Composed Image Retrieval Based on Diffusion Model Priors and Large Language Models</h1>
65
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
66
+ <a href=''><img src='https://img.shields.io/badge/Project_Page-BrushEdit-green' alt='Project Page'></a>
67
+ <a href=''><img src='https://img.shields.io/badge/Paper-Arxiv-blue'></a>
68
+ <a href=''><img src='https://img.shields.io/badge/Code-Github-orange'></a>
69
+
70
+ </div>
71
+ </br>
72
+ </div>
73
+ """
74
+ descriptions = r"""
75
+ Demo for ZS-CIR"""
76
+
77
+ instructions = r"""
78
+ Demo for ZS-CIR"""
79
+
80
+ tips = r"""
81
+ Demo for ZS-CIR
82
+
83
+ """
84
+
85
+ citation = r"""
86
+ Demo for ZS-CIR"""
87
+
88
+ # - - - - - examples - - - - - #
89
+ EXAMPLES = [
90
+
91
+ [
92
+ Image.open("./assets/frog/frog.jpeg").convert("RGBA"),
93
+ "add a magic hat on frog head.",
94
+ 642087011,
95
+ "frog",
96
+ "frog",
97
+ True,
98
+ False,
99
+ "GPT4-o (Highly Recommended)"
100
+ ],
101
+ [
102
+ Image.open("./assets/chinese_girl/chinese_girl.png").convert("RGBA"),
103
+ "replace the background to ancient China.",
104
+ 648464818,
105
+ "chinese_girl",
106
+ "chinese_girl",
107
+ True,
108
+ False,
109
+ "GPT4-o (Highly Recommended)"
110
+ ],
111
+ [
112
+ Image.open("./assets/angel_christmas/angel_christmas.png").convert("RGBA"),
113
+ "remove the deer.",
114
+ 648464818,
115
+ "angel_christmas",
116
+ "angel_christmas",
117
+ False,
118
+ False,
119
+ "GPT4-o (Highly Recommended)"
120
+ ],
121
+ [
122
+ Image.open("./assets/sunflower_girl/sunflower_girl.png").convert("RGBA"),
123
+ "add a wreath on head.",
124
+ 648464818,
125
+ "sunflower_girl",
126
+ "sunflower_girl",
127
+ True,
128
+ False,
129
+ "GPT4-o (Highly Recommended)"
130
+ ],
131
+ [
132
+ Image.open("./assets/girl_on_sun/girl_on_sun.png").convert("RGBA"),
133
+ "add a butterfly fairy.",
134
+ 648464818,
135
+ "girl_on_sun",
136
+ "girl_on_sun",
137
+ True,
138
+ False,
139
+ "GPT4-o (Highly Recommended)"
140
+ ],
141
+ [
142
+ Image.open("./assets/spider_man_rm/spider_man.png").convert("RGBA"),
143
+ "remove the christmas hat.",
144
+ 642087011,
145
+ "spider_man_rm",
146
+ "spider_man_rm",
147
+ False,
148
+ False,
149
+ "GPT4-o (Highly Recommended)"
150
+ ],
151
+ [
152
+ Image.open("./assets/anime_flower/anime_flower.png").convert("RGBA"),
153
+ "remove the flower.",
154
+ 642087011,
155
+ "anime_flower",
156
+ "anime_flower",
157
+ False,
158
+ False,
159
+ "GPT4-o (Highly Recommended)"
160
+ ],
161
+ [
162
+ Image.open("./assets/chenduling/chengduling.jpg").convert("RGBA"),
163
+ "replace the clothes to a delicated floral skirt.",
164
+ 648464818,
165
+ "chenduling",
166
+ "chenduling",
167
+ True,
168
+ False,
169
+ "GPT4-o (Highly Recommended)"
170
+ ],
171
+ [
172
+ Image.open("./assets/hedgehog_rp_bg/hedgehog.png").convert("RGBA"),
173
+ "make the hedgehog in Italy.",
174
+ 648464818,
175
+ "hedgehog_rp_bg",
176
+ "hedgehog_rp_bg",
177
+ True,
178
+ False,
179
+ "GPT4-o (Highly Recommended)"
180
+ ],
181
+
182
+ ]
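# Illustrative sketch, not part of this commit: each EXAMPLES row is positional and is fed to
# gr.Examples in the same order as its inputs list (image, instruction prompt, seed, init name,
# example name, blending flag, resize flag, VLM choice), mirroring the wiring in the Gradio UI.
example_row = EXAMPLES[0]
assert len(example_row) == 8 and example_row[1] == "add a magic hat on frog head."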
183
+
184
+ INPUT_IMAGE_PATH = {
185
+ "frog": "./assets/frog/frog.jpeg",
186
+ "chinese_girl": "./assets/chinese_girl/chinese_girl.png",
187
+ "angel_christmas": "./assets/angel_christmas/angel_christmas.png",
188
+ "sunflower_girl": "./assets/sunflower_girl/sunflower_girl.png",
189
+ "girl_on_sun": "./assets/girl_on_sun/girl_on_sun.png",
190
+ "spider_man_rm": "./assets/spider_man_rm/spider_man.png",
191
+ "anime_flower": "./assets/anime_flower/anime_flower.png",
192
+ "chenduling": "./assets/chenduling/chengduling.jpg",
193
+ "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/hedgehog.png",
194
+ }
195
+ MASK_IMAGE_PATH = {
196
+ "frog": "./assets/frog/mask_f7b350de-6f2c-49e3-b535-995c486d78e7.png",
197
+ "chinese_girl": "./assets/chinese_girl/mask_54759648-0989-48e0-bc82-f20e28b5ec29.png",
198
+ "angel_christmas": "./assets/angel_christmas/mask_f15d9b45-c978-4e3d-9f5f-251e308560c3.png",
199
+ "sunflower_girl": "./assets/sunflower_girl/mask_99cc50b4-7dc4-4de5-8748-ec10772f0317.png",
200
+ "girl_on_sun": "./assets/girl_on_sun/mask_264eac8b-8b65-479c-9755-020a60880c37.png",
201
+ "spider_man_rm": "./assets/spider_man_rm/mask_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png",
202
+ "anime_flower": "./assets/anime_flower/mask_37553172-9b38-4727-bf2e-37d7e2b93461.png",
203
+ "chenduling": "./assets/chenduling/mask_68e3ff6f-da07-4b37-91df-13d6eed7b997.png",
204
+ "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/mask_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png",
205
+ }
206
+ MASKED_IMAGE_PATH = {
207
+ "frog": "./assets/frog/masked_image_f7b350de-6f2c-49e3-b535-995c486d78e7.png",
208
+ "chinese_girl": "./assets/chinese_girl/masked_image_54759648-0989-48e0-bc82-f20e28b5ec29.png",
209
+ "angel_christmas": "./assets/angel_christmas/masked_image_f15d9b45-c978-4e3d-9f5f-251e308560c3.png",
210
+ "sunflower_girl": "./assets/sunflower_girl/masked_image_99cc50b4-7dc4-4de5-8748-ec10772f0317.png",
211
+ "girl_on_sun": "./assets/girl_on_sun/masked_image_264eac8b-8b65-479c-9755-020a60880c37.png",
212
+ "spider_man_rm": "./assets/spider_man_rm/masked_image_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png",
213
+ "anime_flower": "./assets/anime_flower/masked_image_37553172-9b38-4727-bf2e-37d7e2b93461.png",
214
+ "chenduling": "./assets/chenduling/masked_image_68e3ff6f-da07-4b37-91df-13d6eed7b997.png",
215
+ "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/masked_image_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png",
216
+ }
217
+ OUTPUT_IMAGE_PATH = {
218
+ "frog": "./assets/frog/image_edit_f7b350de-6f2c-49e3-b535-995c486d78e7_1.png",
219
+ "chinese_girl": "./assets/chinese_girl/image_edit_54759648-0989-48e0-bc82-f20e28b5ec29_1.png",
220
+ "angel_christmas": "./assets/angel_christmas/image_edit_f15d9b45-c978-4e3d-9f5f-251e308560c3_0.png",
221
+ "sunflower_girl": "./assets/sunflower_girl/image_edit_99cc50b4-7dc4-4de5-8748-ec10772f0317_3.png",
222
+ "girl_on_sun": "./assets/girl_on_sun/image_edit_264eac8b-8b65-479c-9755-020a60880c37_0.png",
223
+ "spider_man_rm": "./assets/spider_man_rm/image_edit_a5d410e6-8e8d-432f-8144-defbc3e1eae9_0.png",
224
+ "anime_flower": "./assets/anime_flower/image_edit_37553172-9b38-4727-bf2e-37d7e2b93461_2.png",
225
+ "chenduling": "./assets/chenduling/image_edit_68e3ff6f-da07-4b37-91df-13d6eed7b997_0.png",
226
+ "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/image_edit_db7f8bf8-8349-46d3-b14e-43d67fbe25d3_3.png",
227
+ }
228
+
229
+ # os.environ['GRADIO_TEMP_DIR'] = 'gradio_temp_dir'
230
+ # os.makedirs('gradio_temp_dir', exist_ok=True)
231
+
232
+ VLM_MODEL_NAMES = list(vlms_template.keys())
233
+ DEFAULT_VLM_MODEL_NAME = "Qwen2-VL-7B-Instruct (Default)"
234
+
235
+
236
+ BASE_MODELS = list(base_models_template.keys())
237
+ DEFAULT_BASE_MODEL = "realisticVision (Default)"
238
+
239
+ ASPECT_RATIO_LABELS = list(aspect_ratios)
240
+ DEFAULT_ASPECT_RATIO = ASPECT_RATIO_LABELS[0]
241
+
242
+
243
+ ## init device
244
+ try:
245
+ if torch.cuda.is_available():
246
+ device = "cuda"
247
+ elif sys.platform == "darwin" and torch.backends.mps.is_available():
248
+ device = "mps"
249
+ else:
250
+ device = "cpu"
251
+ except:
252
+ device = "cpu"
253
+
254
+ # ## init torch dtype
255
+ # if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
256
+ # torch_dtype = torch.bfloat16
257
+ # else:
258
+ # torch_dtype = torch.float16
259
+
260
+ # if device == "mps":
261
+ # torch_dtype = torch.float16
262
+
263
+ torch_dtype = torch.float16
264
+
265
+
266
+
267
+ # download hf models
268
+ BrushEdit_path = "models/"
269
+ if not os.path.exists(BrushEdit_path):
270
+ BrushEdit_path = snapshot_download(
271
+ repo_id="TencentARC/BrushEdit",
272
+ local_dir=BrushEdit_path,
273
+ token=os.getenv("HF_TOKEN"),
274
+ )
275
+
276
+ ## init default VLM
277
+ vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[DEFAULT_VLM_MODEL_NAME]
278
+ if vlm_processor != "" and vlm_model != "":
279
+ vlm_model.to(device)
280
+ else:
281
+ raise gr.Error("Please download the default VLM model " + DEFAULT_VLM_MODEL_NAME + " first.")
282
+
283
+ ## init default LLM
284
+ llm_model = OpenAI(api_key=base_api_key, base_url=base_openai_url)
285
+
286
+ ## init base model
287
+ base_model_path = os.path.join(BrushEdit_path, "base_model/realisticVisionV60B1_v51VAE")
288
+ brushnet_path = os.path.join(BrushEdit_path, "brushnetX")
289
+ sam_path = os.path.join(BrushEdit_path, "sam/sam_vit_h_4b8939.pth")
290
+ groundingdino_path = os.path.join(BrushEdit_path, "grounding_dino/groundingdino_swint_ogc.pth")
291
+
292
+
293
+ # input brushnetX ckpt path
294
+ brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch_dtype)
295
+ pipe = StableDiffusionBrushNetPipeline.from_pretrained(
296
+ base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
297
+ )
298
+ # speed up diffusion process with faster scheduler and memory optimization
299
+ pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
300
+ # remove following line if xformers is not installed or when using Torch 2.0.
301
+ # pipe.enable_xformers_memory_efficient_attention()
302
+ pipe.enable_model_cpu_offload()
303
+
304
+
305
+ ## init SAM
306
+ sam = build_sam(checkpoint=sam_path)
307
+ sam.to(device=device)
308
+ sam_predictor = SamPredictor(sam)
309
+ sam_automask_generator = SamAutomaticMaskGenerator(sam)
310
+
311
+ ## init groundingdino_model
312
+ config_file = 'app/utils/GroundingDINO_SwinT_OGC.py'
313
+ groundingdino_model = load_grounding_dino_model(config_file, groundingdino_path, device=device)
314
+
315
+ ## Ordinary function
316
+ def crop_and_resize(image: Image.Image,
317
+ target_width: int,
318
+ target_height: int) -> Image.Image:
319
+ """
320
+ Crops and resizes an image while preserving the aspect ratio.
321
+
322
+ Args:
323
+ image (Image.Image): Input PIL image to be cropped and resized.
324
+ target_width (int): Target width of the output image.
325
+ target_height (int): Target height of the output image.
326
+
327
+ Returns:
328
+ Image.Image: Cropped and resized image.
329
+ """
330
+ # Original dimensions
331
+ original_width, original_height = image.size
332
+ original_aspect = original_width / original_height
333
+ target_aspect = target_width / target_height
334
+
335
+ # Calculate crop box to maintain aspect ratio
336
+ if original_aspect > target_aspect:
337
+ # Crop horizontally
338
+ new_width = int(original_height * target_aspect)
339
+ new_height = original_height
340
+ left = (original_width - new_width) / 2
341
+ top = 0
342
+ right = left + new_width
343
+ bottom = original_height
344
+ else:
345
+ # Crop vertically
346
+ new_width = original_width
347
+ new_height = int(original_width / target_aspect)
348
+ left = 0
349
+ top = (original_height - new_height) / 2
350
+ right = original_width
351
+ bottom = top + new_height
352
+
353
+ # Crop and resize
354
+ cropped_image = image.crop((left, top, right, bottom))
355
+ resized_image = cropped_image.resize((target_width, target_height), Image.NEAREST)
356
+ return resized_image
357
+
358
+
359
+ ## Ordinary function
360
+ def resize(image: Image.Image,
361
+ target_width: int,
362
+ target_height: int) -> Image.Image:
363
+ """
364
+ Resizes an image to the target size (the aspect ratio is not preserved).
365
+
366
+ Args:
367
+ image (Image.Image): Input PIL image to be cropped and resized.
368
+ target_width (int): Target width of the output image.
369
+ target_height (int): Target height of the output image.
370
+
371
+ Returns:
372
+ Image.Image: Resized image.
373
+ """
374
+ # Resize directly to the target size (no cropping)
375
+ resized_image = image.resize((target_width, target_height), Image.NEAREST)
376
+ return resized_image
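# Illustrative sketch, not part of this commit: how crop_and_resize and resize above differ.
# crop_and_resize center-crops to the target aspect ratio before resizing, while resize
# stretches directly to the target size, so its output may be distorted. The synthetic
# image below is a stand-in for a real upload.
from PIL import Image

example = Image.new("RGB", (1200, 800), "white")      # 3:2 stand-in photo
square_crop = crop_and_resize(example, 640, 640)      # center-cropped to 1:1, then 640x640
square_stretch = resize(example, 640, 640)            # stretched to 640x640 without cropping
assert square_crop.size == square_stretch.size == (640, 640)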
377
+
378
+
379
+ def move_mask_func(mask, direction, units):
380
+ binary_mask = mask.squeeze()>0
381
+ rows, cols = binary_mask.shape
382
+ moved_mask = np.zeros_like(binary_mask, dtype=bool)
383
+
384
+ if direction == 'down':
385
+ # move down
386
+ moved_mask[max(0, units):, :] = binary_mask[:rows - units, :]
387
+
388
+ elif direction == 'up':
389
+ # move up
390
+ moved_mask[:rows - units, :] = binary_mask[units:, :]
391
+
392
+ elif direction == 'right':
393
+ # move right
394
+ moved_mask[:, max(0, units):] = binary_mask[:, :cols - units]
395
+
396
+ elif direction == 'left':
397
+ # move left
398
+ moved_mask[:, :cols - units] = binary_mask[:, units:]
399
+
400
+ return moved_mask
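# Illustrative sketch, not part of this commit: move_mask_func shifts a binary mask without
# wrapping, so pixels pushed past the border are dropped and the vacated area stays empty.
import numpy as np

toy = np.zeros((5, 5), dtype=np.uint8)
toy[2, 2] = 255                                # single foreground pixel in the centre
shifted = move_mask_func(toy, 'right', 2)      # boolean array with the pixel now at (2, 4)
assert shifted[2, 4] and not shifted[2, 2]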
401
+
402
+
403
+ def random_mask_func(mask, dilation_type='square', dilation_size=20):
404
+ # Randomly select the size of dilation
405
+ binary_mask = mask.squeeze()>0
406
+
407
+ if dilation_type == 'square_dilation':
408
+ structure = np.ones((dilation_size, dilation_size), dtype=bool)
409
+ dilated_mask = binary_dilation(binary_mask, structure=structure)
410
+ elif dilation_type == 'square_erosion':
411
+ structure = np.ones((dilation_size, dilation_size), dtype=bool)
412
+ dilated_mask = binary_erosion(binary_mask, structure=structure)
413
+ elif dilation_type == 'bounding_box':
414
+ # find the most left top and left bottom point
415
+ rows, cols = np.where(binary_mask)
416
+ if len(rows) == 0 or len(cols) == 0:
417
+ return mask # return original mask if no valid points
418
+
419
+ min_row = np.min(rows)
420
+ max_row = np.max(rows)
421
+ min_col = np.min(cols)
422
+ max_col = np.max(cols)
423
+
424
+ # create a bounding box
425
+ dilated_mask = np.zeros_like(binary_mask, dtype=bool)
426
+ dilated_mask[min_row:max_row + 1, min_col:max_col + 1] = True
427
+
428
+ elif dilation_type == 'bounding_ellipse':
429
+ # find the most left top and left bottom point
430
+ rows, cols = np.where(binary_mask)
431
+ if len(rows) == 0 or len(cols) == 0:
432
+ return mask # return original mask if no valid points
433
+
434
+ min_row = np.min(rows)
435
+ max_row = np.max(rows)
436
+ min_col = np.min(cols)
437
+ max_col = np.max(cols)
438
+
439
+ # calculate the center and axis length of the ellipse
440
+ center = ((min_col + max_col) // 2, (min_row + max_row) // 2)
441
+ a = (max_col - min_col) // 2 # half long axis
442
+ b = (max_row - min_row) // 2 # half short axis
443
+
444
+ # create a bounding ellipse
445
+ y, x = np.ogrid[:mask.shape[0], :mask.shape[1]]
446
+ ellipse_mask = ((x - center[0])**2 / a**2 + (y - center[1])**2 / b**2) <= 1
447
+ dilated_mask = np.zeros_like(binary_mask, dtype=bool)
448
+ dilated_mask[ellipse_mask] = True
449
+ else:
450
+ raise ValueError("dilation_type must be 'square_dilation', 'square_erosion', 'bounding_box' or 'bounding_ellipse'")
451
+
452
+ # convert the boolean mask to a uint8 (0/255) single-channel mask
453
+ dilated_mask = np.uint8(dilated_mask[:,:,np.newaxis]) * 255
454
+ return dilated_mask
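# Illustrative sketch, not part of this commit: with dilation_type='bounding_box',
# random_mask_func replaces the mask by the tight axis-aligned box around all foreground
# pixels and returns it as an HxWx1 uint8 array scaled to 0/255.
import numpy as np

toy = np.zeros((6, 6, 1), dtype=np.uint8)
toy[1, 1, 0] = 255
toy[3, 4, 0] = 255                                     # two separated foreground pixels
boxed = random_mask_func(toy, dilation_type='bounding_box')
# the box spans rows 1..3 and cols 1..4, so a point between the two pixels is now inside
assert boxed.shape == (6, 6, 1) and boxed[2, 2, 0] == 255 and boxed[0, 0, 0] == 0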
455
+
456
+
457
+ ## Gradio component function
458
+ def update_vlm_model(vlm_name):
459
+ global vlm_model, vlm_processor
460
+ if vlm_model is not None:
461
+ del vlm_model
462
+ torch.cuda.empty_cache()
463
+
464
+ vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[vlm_name]
465
+
466
+ ## We recommend using preloaded models; otherwise downloading the model can take a long time. You can edit this via vlm_template.py.
467
+ if vlm_type == "llava-next":
468
+ if vlm_processor != "" and vlm_model != "":
469
+ vlm_model.to(device)
470
+ return vlm_model_dropdown
471
+ else:
472
+ if os.path.exists(vlm_local_path):
473
+ vlm_processor = LlavaNextProcessor.from_pretrained(vlm_local_path)
474
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto")
475
+ else:
476
+ if vlm_name == "llava-v1.6-mistral-7b-hf (Preload)":
477
+ vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
478
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", torch_dtype="auto", device_map="auto")
479
+ elif vlm_name == "llama3-llava-next-8b-hf (Preload)":
480
+ vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llama3-llava-next-8b-hf")
481
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llama3-llava-next-8b-hf", torch_dtype="auto", device_map="auto")
482
+ elif vlm_name == "llava-v1.6-vicuna-13b-hf (Preload)":
483
+ vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf")
484
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf", torch_dtype="auto", device_map="auto")
485
+ elif vlm_name == "llava-v1.6-34b-hf (Preload)":
486
+ vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-34b-hf")
487
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-34b-hf", torch_dtype="auto", device_map="auto")
488
+ elif vlm_name == "llava-next-72b-hf (Preload)":
489
+ vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-next-72b-hf")
490
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-next-72b-hf", torch_dtype="auto", device_map="auto")
491
+ elif vlm_type == "qwen2-vl":
492
+ if vlm_processor != "" and vlm_model != "":
493
+ vlm_model.to(device)
494
+ return vlm_model_dropdown
495
+ else:
496
+ if os.path.exists(vlm_local_path):
497
+ vlm_processor = Qwen2VLProcessor.from_pretrained(vlm_local_path)
498
+ vlm_model = Qwen2VLForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto")
499
+ else:
500
+ if vlm_name == "qwen2-vl-2b-instruct (Preload)":
501
+ vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
502
+ vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", torch_dtype="auto", device_map="auto")
503
+ elif vlm_name == "qwen2-vl-7b-instruct (Preload)":
504
+ vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
505
+ vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto")
506
+ elif vlm_name == "qwen2-vl-72b-instruct (Preload)":
507
+ vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-72B-Instruct")
508
+ vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-72B-Instruct", torch_dtype="auto", device_map="auto")
509
+ elif vlm_type == "openai":
510
+ pass
511
+ return "success"
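# Illustrative sketch, not part of this commit: the shape of a vlms_template entry that
# update_vlm_model expects. Each display name maps to a 4-tuple of
# (vlm_type, vlm_local_path, vlm_processor, vlm_model); processor and model may be empty
# strings when the weights are not preloaded. The key and path below are hypothetical.
hypothetical_vlm_entry = {
    "qwen2-vl-7b-instruct (Preload)": (
        "qwen2-vl",                           # vlm_type, selects the qwen2-vl branch above
        "models/vlms/Qwen2-VL-7B-Instruct",   # hypothetical local path checked with os.path.exists
        "",                                   # processor, loaded lazily when empty
        "",                                   # model, loaded lazily when empty
    )
}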
512
+
513
+
514
+ def update_base_model(base_model_name):
515
+ global pipe
516
+ ## We recommend using preloaded models; otherwise downloading the model can take a long time. You can edit this via base_model_template.py.
517
+ if pipe is not None:
518
+ del pipe
519
+ torch.cuda.empty_cache()
520
+ base_model_path, pipe = base_models_template[base_model_name]
521
+ if pipe != "":
522
+ pipe.to(device)
523
+ else:
524
+ if os.path.exists(base_model_path):
525
+ pipe = StableDiffusionBrushNetPipeline.from_pretrained(
526
+ base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
527
+ )
528
+ # pipe.enable_xformers_memory_efficient_attention()
529
+ pipe.enable_model_cpu_offload()
530
+ else:
531
+ raise gr.Error(f"The base model {base_model_name} does not exist")
532
+ return "success"
533
+
534
+
535
+ def process_random_mask(input_image,
536
+ original_image,
537
+ original_mask,
538
+ resize_default,
539
+ aspect_ratio_name,
540
+ ):
541
+
542
+ alpha_mask = input_image["layers"][0].split()[3]
543
+ input_mask = np.asarray(alpha_mask)
544
+
545
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
546
+ if output_w == "" or output_h == "":
547
+ output_h, output_w = original_image.shape[:2]
548
+ if resize_default:
549
+ short_side = min(output_w, output_h)
550
+ scale_ratio = 640 / short_side
551
+ output_w = int(output_w * scale_ratio)
552
+ output_h = int(output_h * scale_ratio)
553
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
554
+ original_image = np.array(original_image)
555
+ if input_mask is not None:
556
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
557
+ input_mask = np.array(input_mask)
558
+ if original_mask is not None:
559
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
560
+ original_mask = np.array(original_mask)
561
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
562
+ else:
563
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
564
+ pass
565
+ else:
566
+ if resize_default:
567
+ short_side = min(output_w, output_h)
568
+ scale_ratio = 640 / short_side
569
+ output_w = int(output_w * scale_ratio)
570
+ output_h = int(output_h * scale_ratio)
571
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
572
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
573
+ original_image = np.array(original_image)
574
+ if input_mask is not None:
575
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
576
+ input_mask = np.array(input_mask)
577
+ if original_mask is not None:
578
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
579
+ original_mask = np.array(original_mask)
580
+
581
+
582
+ if input_mask.max() == 0:
583
+ original_mask = original_mask
584
+ else:
585
+ original_mask = input_mask
586
+
587
+ if original_mask is None:
588
+ raise gr.Error('Please generate mask first')
589
+
590
+ if original_mask.ndim == 2:
591
+ original_mask = original_mask[:,:,None]
592
+
593
+ dilation_type = np.random.choice(['bounding_box', 'bounding_ellipse'])
594
+ random_mask = random_mask_func(original_mask, dilation_type).squeeze()
595
+
596
+ mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
597
+
598
+ masked_image = original_image * (1 - (random_mask[:,:,None]>0))
599
+ masked_image = masked_image.astype(original_image.dtype)
600
+ masked_image = Image.fromarray(masked_image)
601
+
602
+
603
+ return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
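# Illustrative sketch, not part of this commit: the short-side-to-640 rescale that
# process_random_mask (and the other mask helpers) repeats inline when resize_default is set.
def rescale_to_short_side_640(output_w: int, output_h: int) -> tuple:
    scale_ratio = 640 / min(output_w, output_h)
    return int(output_w * scale_ratio), int(output_h * scale_ratio)

# e.g. an 800x512 output becomes 1000x640 (512 is the short side, scale 640/512 = 1.25)
assert rescale_to_short_side_640(800, 512) == (1000, 640)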
604
+
605
+
606
+ def process_dilation_mask(input_image,
607
+ original_image,
608
+ original_mask,
609
+ resize_default,
610
+ aspect_ratio_name,
611
+ dilation_size=20):
612
+
613
+ alpha_mask = input_image["layers"][0].split()[3]
614
+ input_mask = np.asarray(alpha_mask)
615
+
616
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
617
+ if output_w == "" or output_h == "":
618
+ output_h, output_w = original_image.shape[:2]
619
+ if resize_default:
620
+ short_side = min(output_w, output_h)
621
+ scale_ratio = 640 / short_side
622
+ output_w = int(output_w * scale_ratio)
623
+ output_h = int(output_h * scale_ratio)
624
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
625
+ original_image = np.array(original_image)
626
+ if input_mask is not None:
627
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
628
+ input_mask = np.array(input_mask)
629
+ if original_mask is not None:
630
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
631
+ original_mask = np.array(original_mask)
632
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
633
+ else:
634
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
635
+ pass
636
+ else:
637
+ if resize_default:
638
+ short_side = min(output_w, output_h)
639
+ scale_ratio = 640 / short_side
640
+ output_w = int(output_w * scale_ratio)
641
+ output_h = int(output_h * scale_ratio)
642
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
643
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
644
+ original_image = np.array(original_image)
645
+ if input_mask is not None:
646
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
647
+ input_mask = np.array(input_mask)
648
+ if original_mask is not None:
649
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
650
+ original_mask = np.array(original_mask)
651
+
652
+ if input_mask.max() == 0:
653
+ original_mask = original_mask
654
+ else:
655
+ original_mask = input_mask
656
+
657
+ if original_mask is None:
658
+ raise gr.Error('Please generate mask first')
659
+
660
+ if original_mask.ndim == 2:
661
+ original_mask = original_mask[:,:,None]
662
+
663
+ dilation_type = np.random.choice(['square_dilation'])
664
+ random_mask = random_mask_func(original_mask, dilation_type, dilation_size).squeeze()
665
+
666
+ mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
667
+
668
+ masked_image = original_image * (1 - (random_mask[:,:,None]>0))
669
+ masked_image = masked_image.astype(original_image.dtype)
670
+ masked_image = Image.fromarray(masked_image)
671
+
672
+ return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
673
+
674
+
675
+ def process_erosion_mask(input_image,
676
+ original_image,
677
+ original_mask,
678
+ resize_default,
679
+ aspect_ratio_name,
680
+ dilation_size=20):
681
+ alpha_mask = input_image["layers"][0].split()[3]
682
+ input_mask = np.asarray(alpha_mask)
683
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
684
+ if output_w == "" or output_h == "":
685
+ output_h, output_w = original_image.shape[:2]
686
+ if resize_default:
687
+ short_side = min(output_w, output_h)
688
+ scale_ratio = 640 / short_side
689
+ output_w = int(output_w * scale_ratio)
690
+ output_h = int(output_h * scale_ratio)
691
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
692
+ original_image = np.array(original_image)
693
+ if input_mask is not None:
694
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
695
+ input_mask = np.array(input_mask)
696
+ if original_mask is not None:
697
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
698
+ original_mask = np.array(original_mask)
699
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
700
+ else:
701
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
702
+ pass
703
+ else:
704
+ if resize_default:
705
+ short_side = min(output_w, output_h)
706
+ scale_ratio = 640 / short_side
707
+ output_w = int(output_w * scale_ratio)
708
+ output_h = int(output_h * scale_ratio)
709
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
710
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
711
+ original_image = np.array(original_image)
712
+ if input_mask is not None:
713
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
714
+ input_mask = np.array(input_mask)
715
+ if original_mask is not None:
716
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
717
+ original_mask = np.array(original_mask)
718
+
719
+ if input_mask.max() == 0:
720
+ original_mask = original_mask
721
+ else:
722
+ original_mask = input_mask
723
+
724
+ if original_mask is None:
725
+ raise gr.Error('Please generate mask first')
726
+
727
+ if original_mask.ndim == 2:
728
+ original_mask = original_mask[:,:,None]
729
+
730
+ dilation_type = np.random.choice(['square_erosion'])
731
+ random_mask = random_mask_func(original_mask, dilation_type, dilation_size).squeeze()
732
+
733
+ mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
734
+
735
+ masked_image = original_image * (1 - (random_mask[:,:,None]>0))
736
+ masked_image = masked_image.astype(original_image.dtype)
737
+ masked_image = Image.fromarray(masked_image)
738
+
739
+
740
+ return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
741
+
742
+
743
+ def move_mask_left(input_image,
744
+ original_image,
745
+ original_mask,
746
+ moving_pixels,
747
+ resize_default,
748
+ aspect_ratio_name):
749
+
750
+ alpha_mask = input_image["layers"][0].split()[3]
751
+ input_mask = np.asarray(alpha_mask)
752
+
753
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
754
+ if output_w == "" or output_h == "":
755
+ output_h, output_w = original_image.shape[:2]
756
+ if resize_default:
757
+ short_side = min(output_w, output_h)
758
+ scale_ratio = 640 / short_side
759
+ output_w = int(output_w * scale_ratio)
760
+ output_h = int(output_h * scale_ratio)
761
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
762
+ original_image = np.array(original_image)
763
+ if input_mask is not None:
764
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
765
+ input_mask = np.array(input_mask)
766
+ if original_mask is not None:
767
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
768
+ original_mask = np.array(original_mask)
769
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
770
+ else:
771
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
772
+ pass
773
+ else:
774
+ if resize_default:
775
+ short_side = min(output_w, output_h)
776
+ scale_ratio = 640 / short_side
777
+ output_w = int(output_w * scale_ratio)
778
+ output_h = int(output_h * scale_ratio)
779
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
780
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
781
+ original_image = np.array(original_image)
782
+ if input_mask is not None:
783
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
784
+ input_mask = np.array(input_mask)
785
+ if original_mask is not None:
786
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
787
+ original_mask = np.array(original_mask)
788
+
789
+ if input_mask.max() == 0:
790
+ original_mask = original_mask
791
+ else:
792
+ original_mask = input_mask
793
+
794
+ if original_mask is None:
795
+ raise gr.Error('Please generate mask first')
796
+
797
+ if original_mask.ndim == 2:
798
+ original_mask = original_mask[:,:,None]
799
+
800
+ moved_mask = move_mask_func(original_mask, 'left', int(moving_pixels)).squeeze()
801
+ mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
802
+
803
+ masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
804
+ masked_image = masked_image.astype(original_image.dtype)
805
+ masked_image = Image.fromarray(masked_image)
806
+
807
+ if moved_mask.max() <= 1:
808
+ moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
809
+ original_mask = moved_mask
810
+ return [masked_image], [mask_image], original_mask.astype(np.uint8)
811
+
812
+
813
+ def move_mask_right(input_image,
814
+ original_image,
815
+ original_mask,
816
+ moving_pixels,
817
+ resize_default,
818
+ aspect_ratio_name):
819
+ alpha_mask = input_image["layers"][0].split()[3]
820
+ input_mask = np.asarray(alpha_mask)
821
+
822
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
823
+ if output_w == "" or output_h == "":
824
+ output_h, output_w = original_image.shape[:2]
825
+ if resize_default:
826
+ short_side = min(output_w, output_h)
827
+ scale_ratio = 640 / short_side
828
+ output_w = int(output_w * scale_ratio)
829
+ output_h = int(output_h * scale_ratio)
830
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
831
+ original_image = np.array(original_image)
832
+ if input_mask is not None:
833
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
834
+ input_mask = np.array(input_mask)
835
+ if original_mask is not None:
836
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
837
+ original_mask = np.array(original_mask)
838
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
839
+ else:
840
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
841
+ pass
842
+ else:
843
+ if resize_default:
844
+ short_side = min(output_w, output_h)
845
+ scale_ratio = 640 / short_side
846
+ output_w = int(output_w * scale_ratio)
847
+ output_h = int(output_h * scale_ratio)
848
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
849
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
850
+ original_image = np.array(original_image)
851
+ if input_mask is not None:
852
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
853
+ input_mask = np.array(input_mask)
854
+ if original_mask is not None:
855
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
856
+ original_mask = np.array(original_mask)
857
+
858
+ if input_mask.max() == 0:
859
+ original_mask = original_mask
860
+ else:
861
+ original_mask = input_mask
862
+
863
+ if original_mask is None:
864
+ raise gr.Error('Please generate mask first')
865
+
866
+ if original_mask.ndim == 2:
867
+ original_mask = original_mask[:,:,None]
868
+
869
+ moved_mask = move_mask_func(original_mask, 'right', int(moving_pixels)).squeeze()
870
+
871
+ mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
872
+
873
+ masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
874
+ masked_image = masked_image.astype(original_image.dtype)
875
+ masked_image = Image.fromarray(masked_image)
876
+
877
+
878
+ if moved_mask.max() <= 1:
879
+ moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
880
+ original_mask = moved_mask
881
+
882
+ return [masked_image], [mask_image], original_mask.astype(np.uint8)
883
+
884
+
885
+ def move_mask_up(input_image,
886
+ original_image,
887
+ original_mask,
888
+ moving_pixels,
889
+ resize_default,
890
+ aspect_ratio_name):
891
+ alpha_mask = input_image["layers"][0].split()[3]
892
+ input_mask = np.asarray(alpha_mask)
893
+
894
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
895
+ if output_w == "" or output_h == "":
896
+ output_h, output_w = original_image.shape[:2]
897
+ if resize_default:
898
+ short_side = min(output_w, output_h)
899
+ scale_ratio = 640 / short_side
900
+ output_w = int(output_w * scale_ratio)
901
+ output_h = int(output_h * scale_ratio)
902
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
903
+ original_image = np.array(original_image)
904
+ if input_mask is not None:
905
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
906
+ input_mask = np.array(input_mask)
907
+ if original_mask is not None:
908
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
909
+ original_mask = np.array(original_mask)
910
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
911
+ else:
912
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
913
+ pass
914
+ else:
915
+ if resize_default:
916
+ short_side = min(output_w, output_h)
917
+ scale_ratio = 640 / short_side
918
+ output_w = int(output_w * scale_ratio)
919
+ output_h = int(output_h * scale_ratio)
920
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
921
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
922
+ original_image = np.array(original_image)
923
+ if input_mask is not None:
924
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
925
+ input_mask = np.array(input_mask)
926
+ if original_mask is not None:
927
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
928
+ original_mask = np.array(original_mask)
929
+
930
+ if input_mask.max() == 0:
931
+ original_mask = original_mask
932
+ else:
933
+ original_mask = input_mask
934
+
935
+ if original_mask is None:
936
+ raise gr.Error('Please generate mask first')
937
+
938
+ if original_mask.ndim == 2:
939
+ original_mask = original_mask[:,:,None]
940
+
941
+ moved_mask = move_mask_func(original_mask, 'up', int(moving_pixels)).squeeze()
942
+ mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
943
+
944
+ masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
945
+ masked_image = masked_image.astype(original_image.dtype)
946
+ masked_image = Image.fromarray(masked_image)
947
+
948
+ if moved_mask.max() <= 1:
949
+ moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
950
+ original_mask = moved_mask
951
+
952
+ return [masked_image], [mask_image], original_mask.astype(np.uint8)
953
+
954
+
955
+ def move_mask_down(input_image,
956
+ original_image,
957
+ original_mask,
958
+ moving_pixels,
959
+ resize_default,
960
+ aspect_ratio_name):
961
+ alpha_mask = input_image["layers"][0].split()[3]
962
+ input_mask = np.asarray(alpha_mask)
963
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
964
+ if output_w == "" or output_h == "":
965
+ output_h, output_w = original_image.shape[:2]
966
+ if resize_default:
967
+ short_side = min(output_w, output_h)
968
+ scale_ratio = 640 / short_side
969
+ output_w = int(output_w * scale_ratio)
970
+ output_h = int(output_h * scale_ratio)
971
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
972
+ original_image = np.array(original_image)
973
+ if input_mask is not None:
974
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
975
+ input_mask = np.array(input_mask)
976
+ if original_mask is not None:
977
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
978
+ original_mask = np.array(original_mask)
979
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
980
+ else:
981
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
982
+ pass
983
+ else:
984
+ if resize_default:
985
+ short_side = min(output_w, output_h)
986
+ scale_ratio = 640 / short_side
987
+ output_w = int(output_w * scale_ratio)
988
+ output_h = int(output_h * scale_ratio)
989
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
990
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
991
+ original_image = np.array(original_image)
992
+ if input_mask is not None:
993
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
994
+ input_mask = np.array(input_mask)
995
+ if original_mask is not None:
996
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
997
+ original_mask = np.array(original_mask)
998
+
999
+ if input_mask.max() == 0:
1000
+ original_mask = original_mask
1001
+ else:
1002
+ original_mask = input_mask
1003
+
1004
+ if original_mask is None:
1005
+ raise gr.Error('Please generate mask first')
1006
+
1007
+ if original_mask.ndim == 2:
1008
+ original_mask = original_mask[:,:,None]
1009
+
1010
+ moved_mask = move_mask_func(original_mask, 'down', int(moving_pixels)).squeeze()
1011
+ mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
1012
+
1013
+ masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
1014
+ masked_image = masked_image.astype(original_image.dtype)
1015
+ masked_image = Image.fromarray(masked_image)
1016
+
1017
+ if moved_mask.max() <= 1:
1018
+ moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
1019
+ original_mask = moved_mask
1020
+
1021
+ return [masked_image], [mask_image], original_mask.astype(np.uint8)
1022
+
1023
+
1024
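+ # Inverts the active mask (the freshly drawn brush layer if present, otherwise the
+ # stored one) and returns the inverted previews plus invert_mask_state=True so that
+ # the subsequent process() call keeps the inversion instead of overwriting it.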
+ def invert_mask(input_image,
1025
+ original_image,
1026
+ original_mask,
1027
+ ):
1028
+ alpha_mask = input_image["layers"][0].split()[3]
1029
+ input_mask = np.asarray(alpha_mask)
1030
+ if input_mask.max() == 0:
1031
+ original_mask = 1 - (original_mask>0).astype(np.uint8)
1032
+ else:
1033
+ original_mask = 1 - (input_mask>0).astype(np.uint8)
1034
+
1035
+ if original_mask is None:
1036
+ raise gr.Error('Please generate mask first')
1037
+
1038
+ original_mask = original_mask.squeeze()
1039
+ mask_image = Image.fromarray(original_mask*255).convert("RGB")
1040
+
1041
+ if original_mask.ndim == 2:
1042
+ original_mask = original_mask[:,:,None]
1043
+
1044
+ if original_mask.max() <= 1:
1045
+ original_mask = (original_mask * 255).astype(np.uint8)
1046
+
1047
+ masked_image = original_image * (1 - (original_mask>0))
1048
+ masked_image = masked_image.astype(original_image.dtype)
1049
+ masked_image = Image.fromarray(masked_image)
1050
+
1051
+ return [masked_image], [mask_image], original_mask, True
1052
+
1053
+
1054
+
1055
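+ # Clears every piece of UI state (image, masks, prompts, galleries) and empties the
+ # CUDA cache so a new editing session starts from scratch.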
+ def reset_func(input_image,
1056
+ original_image,
1057
+ original_mask,
1058
+ prompt,
1059
+ target_prompt,
1060
+ ):
1061
+ input_image = None
1062
+ original_image = None
1063
+ original_mask = None
1064
+ prompt = ''
1065
+ mask_gallery = []
1066
+ masked_gallery = []
1067
+ result_gallery = []
1068
+ target_prompt = ''
1069
+ if torch.cuda.is_available():
1070
+ torch.cuda.empty_cache()
1071
+ return input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, True, False
1072
+
1073
+
1074
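+ # Loads one of the bundled examples: reads the cached input/mask/masked/result images,
+ # resizes them to the VAE-friendly resolution reported by VaeImageProcessor, and resets
+ # the aspect ratio to "Custom resolution".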
+ def update_example(example_type,
1075
+ prompt,
1076
+ example_change_times):
1077
+ input_image = INPUT_IMAGE_PATH[example_type]
1078
+ image_pil = Image.open(input_image).convert("RGB")
1079
+ mask_gallery = [Image.open(MASK_IMAGE_PATH[example_type]).convert("L")]
1080
+ masked_gallery = [Image.open(MASKED_IMAGE_PATH[example_type]).convert("RGB")]
1081
+ result_gallery = [Image.open(OUTPUT_IMAGE_PATH[example_type]).convert("RGB")]
1082
+ width, height = image_pil.size
1083
+ image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
1084
+ height_new, width_new = image_processor.get_default_height_width(image_pil, height, width)
1085
+ image_pil = image_pil.resize((width_new, height_new))
1086
+ mask_gallery[0] = mask_gallery[0].resize((width_new, height_new))
1087
+ masked_gallery[0] = masked_gallery[0].resize((width_new, height_new))
1088
+ result_gallery[0] = result_gallery[0].resize((width_new, height_new))
1089
+
1090
+ original_image = np.array(image_pil)
1091
+ original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1
1092
+ aspect_ratio = "Custom resolution"
1093
+ example_change_times += 1
1094
+ return input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, "", False, example_change_times
1095
+
1096
+
1097
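+ # Asks the VLM to rewrite the editing instruction into a target caption describing the
+ # edited image; the result fills the target-prompt textbox and can be edited by hand.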
+ def generate_target_prompt(input_image,
1098
+ original_image,
1099
+ prompt):
1100
+ # load example image
1101
+ if isinstance(original_image, str):
1102
+ original_image = input_image
1103
+
1104
+ prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(
1105
+ vlm_processor,
1106
+ vlm_model,
1107
+ original_image,
1108
+ prompt,
1109
+ device)
1110
+ return prompt_after_apply_instruction
1111
+
1112
+
1113
+
1114
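+ # Captioning and retrieval backbones: BLIP (image captioning) and CLIP ViT-B/32
+ # (image/text embeddings for retrieval), both loaded in float16 on the same device.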
+ from app.utils.utils import generate_caption
1115
+ blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
1116
+ clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
1117
+ blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large", torch_dtype=torch.float16).to(device)
1118
+ clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32",torch_dtype=torch.float16).to(device)
1119
+
1120
+
1121
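+ # Initializes the editor when an image is uploaded or an example is selected. Images
+ # whose aspect ratio exceeds 2.0 are rejected; known examples reuse their cached
+ # masks and results, everything else starts with an empty mask.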
+ def init_img(base,
1122
+ init_type,
1123
+ prompt,
1124
+ aspect_ratio,
1125
+ example_change_times
1126
+ ):
1127
+ image_pil = base["background"].convert("RGB")
1128
+ original_image = np.array(image_pil)
1129
+ if max(original_image.shape[0], original_image.shape[1]) * 1.0 / min(original_image.shape[0], original_image.shape[1])>2.0:
1130
+ raise gr.Error('image aspect ratio cannot be larger than 2.0')
1131
+ if init_type in MASK_IMAGE_PATH.keys() and example_change_times < 2:
1132
+ mask_gallery = [Image.open(MASK_IMAGE_PATH[init_type]).convert("L")]
1133
+ masked_gallery = [Image.open(MASKED_IMAGE_PATH[init_type]).convert("RGB")]
1134
+ result_gallery = [Image.open(OUTPUT_IMAGE_PATH[init_type]).convert("RGB")]
1135
+ width, height = image_pil.size
1136
+ image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
1137
+ height_new, width_new = image_processor.get_default_height_width(image_pil, height, width)
1138
+ image_pil = image_pil.resize((width_new, height_new))
1139
+ mask_gallery[0] = mask_gallery[0].resize((width_new, height_new))
1140
+ masked_gallery[0] = masked_gallery[0].resize((width_new, height_new))
1141
+ result_gallery[0] = result_gallery[0].resize((width_new, height_new))
1142
+ original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1
1143
+ return base, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, "", "", "Custom resolution", False, False, example_change_times
1144
+ else:
1145
+ if aspect_ratio not in ASPECT_RATIO_LABELS:
1146
+ aspect_ratio = "Custom resolution"
1147
+ return base, original_image, None, "", None, None, None, "", "", aspect_ratio, True, False, 0
1148
+
1149
+
1150
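+ # Generates the editing mask. If the user has not drawn anything, the VLM first infers
+ # the editing type and the object to edit, and GroundingDINO + SAM produce the mask;
+ # a hand-drawn brush mask always takes priority over the automatic one.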
+ def process_mask(input_image,
1151
+ original_image,
1152
+ prompt,
1153
+ resize_default,
1154
+ aspect_ratio_name):
1155
+ if original_image is None:
1156
+ raise gr.Error('Please upload the input image')
1157
+ if prompt is None:
1158
+ raise gr.Error("Please input your instructions, e.g., remove the xxx")
1159
+
1160
+ ## load mask
1161
+ alpha_mask = input_image["layers"][0].split()[3]
1162
+ input_mask = np.array(alpha_mask)
1163
+
1164
+ # load example image
1165
+ if isinstance(original_image, str):
1166
+ original_image = input_image["background"]
1167
+
1168
+ if input_mask.max() == 0:
1169
+ category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device)
1170
+
1171
+ object_wait_for_edit = vlm_response_object_wait_for_edit(vlm_processor,
1172
+ vlm_model,
1173
+ original_image,
1174
+ category,
1175
+ prompt,
1176
+ device)
1177
+ # original mask: h,w,1 [0, 255]
1178
+ original_mask = vlm_response_mask(
1179
+ vlm_processor,
1180
+ vlm_model,
1181
+ category,
1182
+ original_image,
1183
+ prompt,
1184
+ object_wait_for_edit,
1185
+ sam,
1186
+ sam_predictor,
1187
+ sam_automask_generator,
1188
+ groundingdino_model,
1189
+ device).astype(np.uint8)
1190
+ else:
1191
+ original_mask = input_mask.astype(np.uint8)
1192
+ category = None
1193
+
1194
+ ## resize mask if needed
1195
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
1196
+ if output_w == "" or output_h == "":
1197
+ output_h, output_w = original_image.shape[:2]
1198
+ if resize_default:
1199
+ short_side = min(output_w, output_h)
1200
+ scale_ratio = 640 / short_side
1201
+ output_w = int(output_w * scale_ratio)
1202
+ output_h = int(output_h * scale_ratio)
1203
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1204
+ original_image = np.array(original_image)
1205
+ if input_mask is not None:
1206
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1207
+ input_mask = np.array(input_mask)
1208
+ if original_mask is not None:
1209
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1210
+ original_mask = np.array(original_mask)
1211
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1212
+ else:
1213
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1214
+ pass
1215
+ else:
1216
+ if resize_default:
1217
+ short_side = min(output_w, output_h)
1218
+ scale_ratio = 640 / short_side
1219
+ output_w = int(output_w * scale_ratio)
1220
+ output_h = int(output_h * scale_ratio)
1221
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1222
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1223
+ original_image = np.array(original_image)
1224
+ if input_mask is not None:
1225
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1226
+ input_mask = np.array(input_mask)
1227
+ if original_mask is not None:
1228
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1229
+ original_mask = np.array(original_mask)
1230
+
1231
+
1232
+ if original_mask.ndim == 2:
1233
+ original_mask = original_mask[:,:,None]
1234
+
1235
+ mask_image = Image.fromarray(original_mask.squeeze().astype(np.uint8)).convert("RGB")
1236
+
1237
+ masked_image = original_image * (1 - (original_mask>0))
1238
+ masked_image = masked_image.astype(np.uint8)
1239
+ masked_image = Image.fromarray(masked_image)
1240
+
1241
+ return [masked_image], [mask_image], original_mask.astype(np.uint8), category
1242
+
1243
+
1244
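+ # Main editing entry point: resolves the mask (drawn, stored, or VLM+SAM generated),
+ # rewrites the instruction into a target caption unless one is supplied, and runs the
+ # BrushEdit inpainting pipeline, returning the edited images and updated previews.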
+ def process(input_image,
1245
+ original_image,
1246
+ original_mask,
1247
+ prompt,
1248
+ negative_prompt,
1249
+ control_strength,
1250
+ seed,
1251
+ randomize_seed,
1252
+ guidance_scale,
1253
+ num_inference_steps,
1254
+ num_samples,
1255
+ blending,
1256
+ category,
1257
+ target_prompt,
1258
+ resize_default,
1259
+ aspect_ratio_name,
1260
+ invert_mask_state):
1261
+ if original_image is None:
1262
+ if input_image is None:
1263
+ raise gr.Error('Please upload the input image')
1264
+ else:
1265
+ image_pil = input_image["background"].convert("RGB")
1266
+ original_image = np.array(image_pil)
1267
+ if prompt is None or prompt == "":
1268
+ if target_prompt is None or target_prompt == "":
1269
+ raise gr.Error("Please input your instructions, e.g., remove the xxx")
1270
+
1271
+ alpha_mask = input_image["layers"][0].split()[3]
1272
+ input_mask = np.asarray(alpha_mask)
1273
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
1274
+ if output_w == "" or output_h == "":
1275
+ output_h, output_w = original_image.shape[:2]
1276
+
1277
+ if resize_default:
1278
+ short_side = min(output_w, output_h)
1279
+ scale_ratio = 640 / short_side
1280
+ output_w = int(output_w * scale_ratio)
1281
+ output_h = int(output_h * scale_ratio)
1282
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1283
+ original_image = np.array(original_image)
1284
+ if input_mask is not None:
1285
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1286
+ input_mask = np.array(input_mask)
1287
+ if original_mask is not None:
1288
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1289
+ original_mask = np.array(original_mask)
1290
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1291
+ else:
1292
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1293
+ pass
1294
+ else:
1295
+ if resize_default:
1296
+ short_side = min(output_w, output_h)
1297
+ scale_ratio = 640 / short_side
1298
+ output_w = int(output_w * scale_ratio)
1299
+ output_h = int(output_h * scale_ratio)
1300
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1301
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1302
+ original_image = np.array(original_image)
1303
+ if input_mask is not None:
1304
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1305
+ input_mask = np.array(input_mask)
1306
+ if original_mask is not None:
1307
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1308
+ original_mask = np.array(original_mask)
1309
+
1310
+ if invert_mask_state:
1311
+ original_mask = original_mask
1312
+ else:
1313
+ if input_mask.max() == 0:
1314
+ original_mask = original_mask
1315
+ else:
1316
+ original_mask = input_mask
1317
+
1318
+
1319
+ # inpainting directly if target_prompt is not None
1320
+ if category is not None:
1321
+ pass
1322
+ elif target_prompt is not None and len(target_prompt) >= 1 and original_mask is not None:
1323
+ pass
1324
+ else:
1325
+ try:
1326
+ category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device)
1327
+ except Exception as e:
1328
+ raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
1329
+
1330
+
1331
+ if original_mask is not None:
1332
+ original_mask = np.clip(original_mask, 0, 255).astype(np.uint8)
1333
+ else:
1334
+ try:
1335
+ object_wait_for_edit = vlm_response_object_wait_for_edit(
1336
+ vlm_processor,
1337
+ vlm_model,
1338
+ original_image,
1339
+ category,
1340
+ prompt,
1341
+ device)
1342
+
1343
+ original_mask = vlm_response_mask(vlm_processor,
1344
+ vlm_model,
1345
+ category,
1346
+ original_image,
1347
+ prompt,
1348
+ object_wait_for_edit,
1349
+ sam,
1350
+ sam_predictor,
1351
+ sam_automask_generator,
1352
+ groundingdino_model,
1353
+ device).astype(np.uint8)
1354
+ except Exception as e:
1355
+ raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
1356
+
1357
+ if original_mask.ndim == 2:
1358
+ original_mask = original_mask[:,:,None]
1359
+
1360
+
1361
+ if target_prompt is not None and len(target_prompt) >= 1:
1362
+ prompt_after_apply_instruction = target_prompt
1363
+
1364
+ else:
1365
+ try:
1366
+ prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(
1367
+ vlm_processor,
1368
+ vlm_model,
1369
+ original_image,
1370
+ prompt,
1371
+ device)
1372
+ except Exception as e:
1373
+ raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
1374
+
1375
+ generator = torch.Generator(device).manual_seed(random.randint(0, 2147483647) if randomize_seed else seed)
1376
+
1377
+
1378
+ with torch.autocast(device):
1379
+ image, mask_image, mask_np, init_image_np = BrushEdit_Pipeline(pipe,
1380
+ prompt_after_apply_instruction,
1381
+ original_mask,
1382
+ original_image,
1383
+ generator,
1384
+ num_inference_steps,
1385
+ guidance_scale,
1386
+ control_strength,
1387
+ negative_prompt,
1388
+ num_samples,
1389
+ blending)
1390
+ original_image = np.array(init_image_np)
1391
+ masked_image = original_image * (1 - (mask_np>0))
1392
+ masked_image = masked_image.astype(np.uint8)
1393
+ masked_image = Image.fromarray(masked_image)
1394
+
1395
+ return image, [mask_image], [masked_image], prompt, '', False
1396
+
1397
+
1398
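+ # Batch driver for the CIRR validation split: for every reference image it builds an
+ # all-transparent brush layer, runs init_img + process with the CIRR caption as the
+ # instruction, and saves the edited image, mask, masked preview and a BLIP caption.
+ # The dataset paths below are machine-specific and will need to be adapted.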
+ def process_cirr_images():
1399
+ # Initialize the VLM/SAM models (the actual loading code still needs to be filled in)
1400
+ global vlm_model, sam_predictor, groundingdino_model
1401
+ if not all([vlm_model, sam_predictor, groundingdino_model]):
1402
+ raise RuntimeError("Required models not initialized")
1403
+
1404
+ # Define paths
1405
+ dev_dir = Path("/home/zt/data/BrushEdit/cirr/img_raw/dev")
1406
+ cap_file = Path("/home/zt/data/BrushEdit/cirr/captions/cap.rc2.val.json")
1407
+ output_dirs = {
1408
+ "edited": Path("/home/zt/data/BrushEdit/cirr/img_paint_fix/cirr_edited"),
1409
+ "mask": Path("/home/zt/data/BrushEdit/cirr/img_paint_fix/cirr_mask"),
1410
+ "masked": Path("/home/zt/data/BrushEdit/cirr/img_paint_fix/cirr_masked")
1411
+ }
1412
+
1413
+ # Create output directories
1414
+ for dir_path in output_dirs.values():
1415
+ dir_path.mkdir(parents=True, exist_ok=True)
1416
+
1417
+ # Load captions
1418
+ with open(cap_file, 'r') as f:
1419
+ captions = json.load(f)
1420
+
1421
+ descriptions = {}
1422
+
1423
+ for img_path in dev_dir.glob("*.png"):
1424
+ base_name = img_path.stem
1425
+ caption = next((item["caption"] for item in captions if item.get("reference") == base_name), None)
1426
+
1427
+ if not caption:
1428
+ print(f"Warning: No caption for {base_name}")
1429
+ continue
1430
+
1431
+ try:
1432
+ # Key change 1: build an empty (all-zero) alpha channel
1433
+ rgb_image = Image.open(img_path).convert("RGB")
1434
+ empty_alpha = Image.new("L", rgb_image.size, 0)  # fully transparent alpha channel
1435
+ image = Image.merge("RGBA", (*rgb_image.split(), empty_alpha))
1436
+
1437
+ # Key change 2: initialize through init_img
1438
+ base = {"background": image, "layers": [image]}
1439
+ init_results = init_img(
1440
+ base=base,
1441
+ init_type="custom",  # use custom initialization
1442
+ prompt=caption,
1443
+ aspect_ratio="Custom resolution",
1444
+ example_change_times=0
1445
+ )
1446
+
1447
+ # Retrieve the initialized values
1448
+ input_image = init_results[0]
1449
+ original_image = init_results[1]
1450
+ original_mask = init_results[2]
1451
+
1452
+ # Key change 3: pass the process() arguments correctly
1453
+ process_results = process(
1454
+ input_image=input_image,
1455
+ original_image=original_image,
1456
+ original_mask=original_mask,  # pass the initialized mask
1457
+ prompt=caption,
1458
+ negative_prompt="ugly, low quality",
1459
+ control_strength=1.0,
1460
+ seed=648464818,
1461
+ randomize_seed=False,
1462
+ guidance_scale=7.5,
1463
+ num_inference_steps=50,
1464
+ num_samples=1,
1465
+ blending=True,
1466
+ category=None,
1467
+ target_prompt="",
1468
+ resize_default=True,
1469
+ aspect_ratio_name="Custom resolution",
1470
+ invert_mask_state=False
1471
+ )
1472
+
1473
+ # Handle the results (original logic unchanged)
1474
+ result_images, mask_images, masked_images = process_results[:3]
1475
+
1476
+ # Save images
1477
+ output_dirs["edited"].mkdir(exist_ok=True)
1478
+ result_images[0].save(output_dirs["edited"] / f"{base_name}.png")
1479
+ mask_images[0].save(output_dirs["mask"] / f"{base_name}_mask.png")
1480
+ masked_images[0].save(output_dirs["masked"] / f"{base_name}_masked.png")
1481
+
1482
+ # Generate BLIP description
1483
+ blip_desc, _ = generate_blip_description({"background": image})
1484
+ descriptions[base_name] = {
1485
+ "original_caption": caption,
1486
+ "blip_description": blip_desc
1487
+ }
1488
+
1489
+ print(f"Processed {base_name}")
1490
+
1491
+ except Exception as e:
1492
+ print(f"Error processing {base_name}: {str(e)}")
1493
+ continue
1494
+
1495
+ # Save descriptions
1496
+ with open("/home/zt/data/BrushEdit/cirr/cirr_description_fix.json", 'w') as f:
1497
+ json.dump(descriptions, f, indent=4)
1498
+
1499
+ print("Processing completed!")
1500
+
1501
+
1502
+ # def process_cirr_images():
1503
+ # # Define paths
1504
+ # dev_dir = Path("/home/zt/data/BrushEdit/cirr/img_raw/dev")
1505
+ # cap_file = Path("/home/zt/data/BrushEdit/cirr/captions/cap.rc2.val.json")
1506
+ # output_dirs = {
1507
+ # "edited": Path("/home/zt/data/BrushEdit/cirr/img_paint_fix/cirr_edited"),
1508
+ # "mask": Path("/home/zt/data/BrushEdit/cirr/img_paint_fix/cirr_mask"),
1509
+ # "masked": Path("/home/zt/data/BrushEdit/cirr/img_paint_fix/cirr_masked")
1510
+ # }
1511
+
1512
+ # # Create output directories if they don't exist
1513
+ # for dir_path in output_dirs.values():
1514
+ # dir_path.mkdir(parents=True, exist_ok=True)
1515
+
1516
+ # # Load captions from JSON file
1517
+ # with open(cap_file, 'r') as f:
1518
+ # captions = json.load(f)
1519
+
1520
+ # # Initialize description dictionary
1521
+ # descriptions = {}
1522
+
1523
+ # # Process each PNG image in dev directory
1524
+ # for img_path in dev_dir.glob("*.png"):
1525
+ # # Get base name without extension
1526
+ # base_name = img_path.stem
1527
+
1528
+ # # Find matching caption
1529
+ # caption = None
1530
+ # for item in captions:
1531
+ # if item.get("reference") == base_name:
1532
+ # caption = item.get("caption")
1533
+ # break
1534
+
1535
+ # if caption is None:
1536
+ # print(f"Warning: No caption found for {base_name}")
1537
+ # continue
1538
+
1539
+ # # Load and convert image to RGB
1540
+ # try:
1541
+ # rgb_image = Image.open(img_path).convert("RGB")
1542
+ # a = Image.new("L", rgb_image.size, 255)  # fully opaque alpha channel
1543
+ # image = Image.merge("RGBA", (*rgb_image.split(), a))
1544
+ # except Exception as e:
1545
+ # print(f"Error loading image {img_path}: {e}")
1546
+ # continue
1547
+
1548
+ # # Generate BLIP description
1549
+ # try:
1550
+ # blip_desc, _ = generate_blip_description({"background": image})
1551
+ # except Exception as e:
1552
+ # print(f"Error generating BLIP description for {base_name}: {e}")
1553
+ # continue
1554
+
1555
+ # # Process image
1556
+ # try:
1557
+ # # Prepare input parameters for process function
1558
+ # input_image = {"background": image, "layers": [image]}
1559
+ # original_image = np.array(image)
1560
+ # original_mask = None
1561
+ # prompt = caption
1562
+ # negative_prompt = "ugly, low quality"
1563
+ # control_strength = 1.0
1564
+ # seed = 648464818
1565
+ # randomize_seed = False
1566
+ # guidance_scale = 7.5
1567
+ # num_inference_steps = 50
1568
+ # num_samples = 1
1569
+ # blending = True
1570
+ # category = None
1571
+ # target_prompt = ""
1572
+ # resize_default = True
1573
+ # aspect_ratio = "Custom resolution"
1574
+ # invert_mask_state = False
1575
+
1576
+ # # Call process function and handle return values properly
1577
+ # process_results = process(
1578
+ # input_image,
1579
+ # original_image,
1580
+ # original_mask,
1581
+ # prompt,
1582
+ # negative_prompt,
1583
+ # control_strength,
1584
+ # seed,
1585
+ # randomize_seed,
1586
+ # guidance_scale,
1587
+ # num_inference_steps,
1588
+ # num_samples,
1589
+ # blending,
1590
+ # category,
1591
+ # target_prompt,
1592
+ # resize_default,
1593
+ # aspect_ratio,
1594
+ # invert_mask_state
1595
+ # )
1596
+
1597
+ # # Extract results safely
1598
+ # result_images = process_results[0]
1599
+ # mask_images = process_results[1]
1600
+ # masked_images = process_results[2]
1601
+
1602
+ # # Ensure we have valid images to save
1603
+ # if not result_images or not mask_images or not masked_images:
1604
+ # print(f"Warning: No output images generated for {base_name}")
1605
+ # continue
1606
+
1607
+ # # Save processed images
1608
+ # # Save edited image
1609
+ # edited_path = output_dirs["edited"] / f"{base_name}.png"
1610
+ # if isinstance(result_images, (list, tuple)):
1611
+ # result_images[0].save(edited_path)
1612
+ # else:
1613
+ # result_images.save(edited_path)
1614
+
1615
+ # # Save mask image
1616
+ # mask_path = output_dirs["mask"] / f"{base_name}_mask.png"
1617
+ # if isinstance(mask_images, (list, tuple)):
1618
+ # mask_images[0].save(mask_path)
1619
+ # else:
1620
+ # mask_images.save(mask_path)
1621
+
1622
+ # # Save masked image
1623
+ # masked_path = output_dirs["masked"] / f"{base_name}_masked.png"
1624
+ # if isinstance(masked_images, (list, tuple)):
1625
+ # masked_images[0].save(masked_path)
1626
+ # else:
1627
+ # masked_images.save(masked_path)
1628
+
1629
+ # # Store description
1630
+ # descriptions[base_name] = {
1631
+ # "original_caption": caption,
1632
+ # "blip_description": blip_desc
1633
+ # }
1634
+
1635
+ # print(f"Successfully processed {base_name}")
1636
+
1637
+ # except Exception as e:
1638
+ # print(f"Error processing image {base_name}: {e}")
1639
+ # continue
1640
+
1641
+ # # Save descriptions to JSON file
1642
+ # with open("/home/zt/data/BrushEdit/cirr/cirr_description_fix.json", 'w') as f:
1643
+ # json.dump(descriptions, f, indent=4)
1644
+
1645
+ # print("Processing completed!")
1646
+
1647
+
1648
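+ # Produces a BLIP caption of the uploaded image; the same string is returned twice so
+ # it can update both the hidden state and the visible textbox.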
+ def generate_blip_description(input_image):
1649
+ if input_image is None:
1650
+ return "", "Input image cannot be None"
1651
+ try:
1652
+ image_pil = input_image["background"].convert("RGB")
1653
+ except KeyError:
1654
+ return "", "Input image missing 'background' key"
1655
+ except AttributeError as e:
1656
+ return "", f"Invalid image object: {str(e)}"
1657
+ try:
1658
+ description = generate_caption(blip_processor, blip_model, image_pil, device)
1659
+ return description, description  # update both the state and the visible textbox
1660
+ except Exception as e:
1661
+ return "", f"Caption generation failed: {str(e)}"
1662
+
1663
+
1664
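+ # Despite the GPT-4o naming kept from the original BrushEdit UI, this handler points
+ # the OpenAI client at the DeepSeek endpoint and sends a one-turn test message to
+ # validate the supplied key.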
+ def submit_GPT4o_KEY(GPT4o_KEY):
1665
+ global vlm_model, vlm_processor
1666
+ if vlm_model is not None:
1667
+ del vlm_model
1668
+ torch.cuda.empty_cache()
1669
+ try:
1670
+ vlm_model = OpenAI(api_key=GPT4o_KEY, base_url="https://api.deepseek.com")
1671
+ vlm_processor = ""
1672
+ response = vlm_model.chat.completions.create(
1673
+ model="deepseek-chat",
1674
+ messages=[
1675
+ {"role": "system", "content": "You are a helpful assistant."},
1676
+ {"role": "user", "content": "Hello."}
1677
+ ]
1678
+ )
1679
+ response_str = response.choices[0].message.content
1680
+
1681
+ return "Success. " + response_str, "GPT4-o (Highly Recommended)"
1682
+ except Exception as e:
1683
+ return "Invalid GPT4o API Key", "GPT4-o (Highly Recommended)"
1684
+
1685
+
1686
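+ # Sends a trivial chat completion through the shared llm_model client to confirm the
+ # DeepSeek API key works before the enhancement/decomposition buttons are used.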
+ def verify_deepseek_api():
1687
+ try:
1688
+ response = llm_model.chat.completions.create(
1689
+ model="deepseek-chat",
1690
+ messages=[
1691
+ {"role": "system", "content": "You are a helpful assistant."},
1692
+ {"role": "user", "content": "Hello."}
1693
+ ]
1694
+ )
1695
+ response_str = response.choices[0].message.content
1696
+
1697
+ return True, "Success. " + response_str
1698
+
1699
+ except Exception as e:
1700
+ return False, "Invalid DeepSeek API Key"
1701
+
1702
+
1703
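+ # LLM prompt rewriting: the first helper merges the BLIP caption with the editing
+ # instruction into one target-image description, the second decomposes that merged
+ # description into a structured query (both via DeepSeek chat completions).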
+ def llm_enhanced_prompt_after_apply_instruction(image_caption, editing_prompt):
1704
+ try:
1705
+ messages = create_apply_editing_messages_deepseek(image_caption, editing_prompt)
1706
+ response = llm_model.chat.completions.create(
1707
+ model="deepseek-chat",
1708
+ messages=messages
1709
+ )
1710
+ response_str = response.choices[0].message.content
1711
+ return response_str
1712
+ except Exception as e:
1713
+ raise gr.Error(f"整合指令时遇到错误: {str(e)},请检查控制台日志获取详细信息")
1714
+
1715
+
1716
+ def llm_decomposed_prompt_after_apply_instruction(integrated_query):
1717
+ try:
1718
+ messages = create_decomposed_query_messages_deepseek(integrated_query)
1719
+ response = llm_model.chat.completions.create(
1720
+ model="deepseek-chat",
1721
+ messages=messages
1722
+ )
1723
+ response_str = response.choices[0].message.content
1724
+ return response_str
1725
+ except Exception as e:
1726
+ raise gr.Error(f"分解指令时遇到错误: {str(e)},请检查控制台日志获取详细信息")
1727
+
1728
+
1729
+ def enhance_description(blip_description, prompt):
1730
+ try:
1731
+ if not prompt or not blip_description:
1732
+ print("Empty prompt or blip_description detected")
1733
+ return "", ""
1734
+
1735
+ print(f"Enhancing with prompt: {prompt}")
1736
+ enhanced_description = llm_enhanced_prompt_after_apply_instruction(blip_description, prompt)
1737
+ return enhanced_description, enhanced_description
1738
+
1739
+ except Exception as e:
1740
+ print(f"Enhancement failed: {str(e)}")
1741
+ return "Error occurred", "Error occurred"
1742
+
1743
+
1744
+ def decompose_description(enhanced_description):
1745
+ try:
1746
+ if not enhanced_description:
1747
+ print("Empty enhanced_description detected")
1748
+ return "", ""
1749
+
1750
+ print(f"Decomposing the enhanced description: {enhanced_description}")
1751
+ decomposed_description = llm_decomposed_prompt_after_apply_instruction(enhanced_description)
1752
+ return decomposed_description, decomposed_description
1753
+
1754
+ except Exception as e:
1755
+ print(f"Decomposition failed: {str(e)}")
1756
+ return "Error occurred", "Error occurred"
1757
+
1758
+
1759
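+ # Composed retrieval: encode the newest edited image and the enhanced caption with
+ # CLIP, average the two L2-normalised embeddings, re-normalise, and run a Faiss kNN
+ # search over the precomputed CIRR gallery index (roughly
+ #   mixed = normalize((normalize(img) + normalize(txt)) / 2)).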
+ @torch.no_grad()
1760
+ def mix_and_search(enhanced_text: str, gallery_images: list):
1761
+ # Grab the most recently generated image tuple from the gallery
1762
+ latest_item = gallery_images[-1] if gallery_images else None
1763
+
1764
+ # Initialize the list of features to be fused
1765
+ features = []
1766
+
1767
+ # Image feature extraction
1768
+ if latest_item and isinstance(latest_item, tuple):
1769
+ try:
1770
+ image_path = latest_item[0]
1771
+ pil_image = Image.open(image_path).convert("RGB")
1772
+
1773
+ # Preprocess the image with CLIPProcessor
1774
+ image_inputs = clip_processor(
1775
+ images=pil_image,
1776
+ return_tensors="pt"
1777
+ ).to(device)
1778
+
1779
+ image_features = clip_model.get_image_features(**image_inputs)
1780
+ features.append(F.normalize(image_features, dim=-1))
1781
+ except Exception as e:
1782
+ print(f"图像处理失败: {str(e)}")
1783
+
1784
+ # Text feature extraction
1785
+ if enhanced_text.strip():
1786
+ text_inputs = clip_processor(
1787
+ text=enhanced_text,
1788
+ return_tensors="pt",
1789
+ padding=True,
1790
+ truncation=True
1791
+ ).to(device)
1792
+
1793
+ text_features = clip_model.get_text_features(**text_inputs)
1794
+ features.append(F.normalize(text_features, dim=-1))
1795
+
1796
+ if not features:
1797
+ return []
1798
+
1799
+
1800
+ # Feature fusion and retrieval
1801
+ mixed = sum(features) / len(features)
1802
+ mixed = F.normalize(mixed, dim=-1)
1803
+
1804
+ # Load the Faiss index and the image-path mapping
1805
+ # index_path = "/home/zt/data/open-images/train/knn.index"
1806
+ # input_data_dir = Path("/home/zt/data/open-images/train/embedding_folder/metadata")
1807
+ # base_image_dir = Path("/home/zt/data/open-images/train/")
1808
+
1809
+ index_path = "/home/zt/data/BrushEdit/cirr/img_raw/dev/dev_knn.index"
1810
+ input_data_dir = Path("/home/zt/data/BrushEdit/cirr/img_raw/dev/dev_embedding_folder/metadata")
1811
+ base_image_dir = Path("/home/zt/data/BrushEdit/cirr/img_raw/")
1812
+
1813
+ # Sort the parquet files by the number in the filename and read them directly
1814
+ parquet_files = sorted(
1815
+ input_data_dir.glob('*.parquet'),
1816
+ key=lambda x: int(x.stem.split("_")[-1])
1817
+ )
1818
+
1819
+ # Concatenate all parquet data
1820
+ dfs = [pd.read_parquet(file) for file in parquet_files]  # read the files inline
1821
+ df = pd.concat(dfs, ignore_index=True)
1822
+ image_paths = df["image_path"].tolist()
1823
+
1824
+ # Read the Faiss index
1825
+ index = faiss.read_index(index_path)
1826
+ assert mixed.shape[1] == index.d, "feature dimension does not match the index"
1827
+
1828
+ # Run the nearest-neighbour search
1829
+ mixed = mixed.cpu().detach().numpy().astype('float32')
1830
+ distances, indices = index.search(mixed, 50)
1831
+
1832
+ # Fetch and validate the retrieved image paths
1833
+ retrieved_images = []
1834
+ for idx in indices[0]:
1835
+ if 0 <= idx < len(image_paths):
1836
+ img_path = base_image_dir / image_paths[idx]
1837
+ try:
1838
+ if img_path.exists():
1839
+ retrieved_images.append(Image.open(img_path).convert("RGB"))
1840
+ else:
1841
+ print(f"警告:文件缺失 {img_path}")
1842
+ except Exception as e:
1843
+ print(f"图片加载失败: {str(e)}")
1844
+
1845
+ return retrieved_images if retrieved_images else []
1846
+
1847
+
1848
+
1849
+ if __name__ == "__main__":
1850
+ process_cirr_images()
1851
+
1852
+
1853
+ # block = gr.Blocks(
1854
+ # theme=gr.themes.Soft(
1855
+ # radius_size=gr.themes.sizes.radius_none,
1856
+ # text_size=gr.themes.sizes.text_md
1857
+ # )
1858
+ # )
1859
+
1860
+ # with block as demo:
1861
+ # with gr.Row():
1862
+ # with gr.Column():
1863
+ # gr.HTML(head)
1864
+ # gr.Markdown(descriptions)
1865
+ # with gr.Accordion(label="🧭 Instructions:", open=True, elem_id="accordion"):
1866
+ # with gr.Row(equal_height=True):
1867
+ # gr.Markdown(instructions)
1868
+
1869
+ # original_image = gr.State(value=None)
1870
+ # original_mask = gr.State(value=None)
1871
+ # category = gr.State(value=None)
1872
+ # status = gr.State(value=None)
1873
+ # invert_mask_state = gr.State(value=False)
1874
+ # example_change_times = gr.State(value=0)
1875
+ # deepseek_verified = gr.State(value=False)
1876
+ # blip_description = gr.State(value="")
1877
+ # enhanced_description = gr.State(value="")
1878
+ # decomposed_description = gr.State(value="")
1879
+
1880
+ # with gr.Row():
1881
+ # with gr.Column():
1882
+ # with gr.Group():
1883
+ # input_image = gr.ImageEditor(
1884
+ # label="参考图像",
1885
+ # type="pil",
1886
+ # brush=gr.Brush(colors=["#FFFFFF"], default_size = 30, color_mode="fixed"),
1887
+ # layers = False,
1888
+ # interactive=True,
1889
+ # # height=1024,
1890
+ # height=412,
1891
+ # sources=["upload"],
1892
+ # placeholder="🫧 点击此处或下面的图标上传图像 🫧",
1893
+ # )
1894
+ # prompt = gr.Textbox(label="修改指令", placeholder="😜 在此处输入你对参考图像的修改预期 😜", value="",lines=1)
1895
+
1896
+ # with gr.Group():
1897
+ # mask_button = gr.Button("💎 掩膜生成")
1898
+ # with gr.Row():
1899
+ # invert_mask_button = gr.Button("👐 掩膜翻转")
1900
+ # random_mask_button = gr.Button("⭕️ 随机掩膜")
1901
+ # with gr.Row():
1902
+ # masked_gallery = gr.Gallery(label="掩膜图像", show_label=True, preview=True, height=360)
1903
+ # mask_gallery = gr.Gallery(label="掩膜", show_label=True, preview=True, height=360)
1904
+
1905
+
1906
+ # with gr.Accordion("高级掩膜选项", open=False, elem_id="accordion1"):
1907
+ # dilation_size = gr.Slider(
1908
+ # label="每次放缩的尺度: ", show_label=True,minimum=0, maximum=50, step=1, value=20
1909
+ # )
1910
+ # with gr.Row():
1911
+ # dilation_mask_button = gr.Button("放大掩膜")
1912
+ # erosion_mask_button = gr.Button("缩小掩膜")
1913
+
1914
+ # moving_pixels = gr.Slider(
1915
+ # label="每次移动的像素:", show_label=True, minimum=0, maximum=50, value=4, step=1
1916
+ # )
1917
+ # with gr.Row():
1918
+ # move_left_button = gr.Button("左移")
1919
+ # move_right_button = gr.Button("右移")
1920
+ # with gr.Row():
1921
+ # move_up_button = gr.Button("上移")
1922
+ # move_down_button = gr.Button("下移")
1923
+
1924
+
1925
+
1926
+ # with gr.Column():
1927
+ # with gr.Row():
1928
+ # deepseek_key = gr.Textbox(label="LLM API密钥", value="sk-d145b963a92649a88843caeb741e8bbc", lines=2, container=False)
1929
+ # verify_deepseek = gr.Button("🔑 验证密钥", scale=0)
1930
+ # blip_output = gr.Textbox(label="1. 原图描述(BLIP生成)", placeholder="🖼️ 上传图片后自动生成图片描述 🖼️", lines=2, interactive=True)
1931
+ # with gr.Row():
1932
+ # enhanced_output = gr.Textbox(label="2. 整合增强版", lines=4, interactive=True, placeholder="🚀 点击右侧按钮生成增强描述 🚀")
1933
+ # enhance_button = gr.Button("✨ 智能整合")
1934
+
1935
+ # with gr.Row():
1936
+ # decomposed_output = gr.Textbox(label="3. 结构分解版", lines=4, interactive=True, placeholder="📝 点击右侧按钮生成结构化描述 📝")
1937
+ # decompose_button = gr.Button("🔧 结构分解")
1938
+
1939
+
1940
+
1941
+ # with gr.Group():
1942
+ # run_button = gr.Button("💫 图像编辑")
1943
+ # result_gallery = gr.Gallery(label="💥 编辑结果", show_label=True, columns=2, preview=True, height=360)
1944
+
1945
+ # with gr.Accordion("高级编辑选项", open=False, elem_id="accordion1"):
1946
+ # vlm_model_dropdown = gr.Dropdown(label="VLM 模型", choices=VLM_MODEL_NAMES, value=DEFAULT_VLM_MODEL_NAME, interactive=True)
1947
+
1948
+ # with gr.Group():
1949
+ # with gr.Row():
1950
+ # # GPT4o_KEY = gr.Textbox(label="GPT4o API Key", placeholder="Please input your GPT4o API Key when use GPT4o VLM (highly recommended).", value="", lines=1)
1951
+ # GPT4o_KEY = gr.Textbox(label="VLM API密钥", value="sk-d145b963a92649a88843caeb741e8bbc", lines=1)
1952
+ # GPT4o_KEY_submit = gr.Button("🔑 验证密钥")
1953
+
1954
+ # aspect_ratio = gr.Dropdown(label="输出纵横比", choices=ASPECT_RATIO_LABELS, value=DEFAULT_ASPECT_RATIO)
1955
+ # resize_default = gr.Checkbox(label="短边裁剪到640像素", value=True)
1956
+ # base_model_dropdown = gr.Dropdown(label="基础模型", choices=BASE_MODELS, value=DEFAULT_BASE_MODEL, interactive=True)
1957
+ # negative_prompt = gr.Text(label="负向提示", max_lines=5, placeholder="请输入你的负向提示", value='ugly, low quality',lines=1)
1958
+ # control_strength = gr.Slider(label="控制强度: ", show_label=True, minimum=0, maximum=1.1, value=1, step=0.01)
1959
+ # with gr.Group():
1960
+ # seed = gr.Slider(label="种子: ", minimum=0, maximum=2147483647, step=1, value=648464818)
1961
+ # randomize_seed = gr.Checkbox(label="随机种子", value=False)
1962
+ # blending = gr.Checkbox(label="混合模式", value=True)
1963
+ # num_samples = gr.Slider(label="生成个数", minimum=0, maximum=4, step=1, value=2)
1964
+ # with gr.Group():
1965
+ # with gr.Row():
1966
+ # guidance_scale = gr.Slider(label="指导尺度", minimum=1, maximum=12, step=0.1, value=7.5)
1967
+ # num_inference_steps = gr.Slider(label="推理步数", minimum=1, maximum=50, step=1, value=50)
1968
+ # target_prompt = gr.Text(label="Input Target Prompt", max_lines=5, placeholder="VLM-generated target prompt, you can first generate if and then modify it (optional)", value='', lines=2)
1969
+
1970
+
1971
+
1972
+ # init_type = gr.Textbox(label="Init Name", value="", visible=False)
1973
+ # example_type = gr.Textbox(label="Example Name", value="", visible=False)
1974
+
1975
+ # with gr.Row():
1976
+ # reset_button = gr.Button("Reset")
1977
+ # retrieve_button = gr.Button("🔍 开始检索")
1978
+
1979
+ # with gr.Row():
1980
+ # retrieve_gallery = gr.Gallery(label="🎊 检索结果", show_label=True, columns=10, preview=True, height=660)
1981
+
1982
+
1983
+ # with gr.Row():
1984
+ # example = gr.Examples(
1985
+ # label="Quick Example",
1986
+ # examples=EXAMPLES,
1987
+ # inputs=[input_image, prompt, seed, init_type, example_type, blending, resize_default, vlm_model_dropdown],
1988
+ # examples_per_page=10,
1989
+ # cache_examples=False,
1990
+ # )
1991
+
1992
+
1993
+ # with gr.Accordion(label="🎬 Feature Details:", open=True, elem_id="accordion"):
1994
+ # with gr.Row(equal_height=True):
1995
+ # gr.Markdown(tips)
1996
+
1997
+ # with gr.Row():
1998
+ # gr.Markdown(citation)
1999
+
2000
+ # ## gr.examples can not be used to update the gr.Gallery, so we need to use the following two functions to update the gr.Gallery.
2001
+ # ## And we need to solve the conflict between the upload and change example functions.
2002
+ # input_image.upload(
2003
+ # init_img,
2004
+ # [input_image, init_type, prompt, aspect_ratio, example_change_times],
2005
+ # [input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, init_type, aspect_ratio, resize_default, invert_mask_state, example_change_times]
2006
+
2007
+ # )
2008
+ # example_type.change(fn=update_example, inputs=[example_type, prompt, example_change_times], outputs=[input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, target_prompt, invert_mask_state, example_change_times])
2009
+
2010
+
2011
+ # ## vlm and base model dropdown
2012
+ # vlm_model_dropdown.change(fn=update_vlm_model, inputs=[vlm_model_dropdown], outputs=[status])
2013
+ # base_model_dropdown.change(fn=update_base_model, inputs=[base_model_dropdown], outputs=[status])
2014
+
2015
+ # GPT4o_KEY_submit.click(fn=submit_GPT4o_KEY, inputs=[GPT4o_KEY], outputs=[GPT4o_KEY, vlm_model_dropdown])
2016
+
2017
+
2018
+ # ips=[input_image,
2019
+ # original_image,
2020
+ # original_mask,
2021
+ # prompt,
2022
+ # negative_prompt,
2023
+ # control_strength,
2024
+ # seed,
2025
+ # randomize_seed,
2026
+ # guidance_scale,
2027
+ # num_inference_steps,
2028
+ # num_samples,
2029
+ # blending,
2030
+ # category,
2031
+ # target_prompt,
2032
+ # resize_default,
2033
+ # aspect_ratio,
2034
+ # invert_mask_state]
2035
+
2036
+ # ## run brushedit
2037
+ # run_button.click(fn=process, inputs=ips, outputs=[result_gallery, mask_gallery, masked_gallery, prompt, target_prompt, invert_mask_state])
2038
+
2039
+
2040
+ # ## mask func
2041
+ # mask_button.click(fn=process_mask, inputs=[input_image, original_image, prompt, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask, category])
2042
+ # random_mask_button.click(fn=process_random_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
2043
+ # dilation_mask_button.click(fn=process_dilation_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[ masked_gallery, mask_gallery, original_mask])
2044
+ # erosion_mask_button.click(fn=process_erosion_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[ masked_gallery, mask_gallery, original_mask])
2045
+ # invert_mask_button.click(fn=invert_mask, inputs=[input_image, original_image, original_mask], outputs=[masked_gallery, mask_gallery, original_mask, invert_mask_state])
2046
+
2047
+ # ## reset func
2048
+ # reset_button.click(fn=reset_func, inputs=[input_image, original_image, original_mask, prompt, target_prompt], outputs=[input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, resize_default, invert_mask_state])
2049
+
2050
+ # input_image.upload(fn=generate_blip_description, inputs=[input_image], outputs=[blip_description, blip_output])
2051
+ # verify_deepseek.click(fn=verify_deepseek_api, outputs=[deepseek_verified, deepseek_key])
2052
+ # enhance_button.click(fn=enhance_description, inputs=[blip_output, prompt], outputs=[enhanced_description, enhanced_output])
2053
+ # decompose_button.click(fn=decompose_description, inputs=[enhanced_output], outputs=[decomposed_description, decomposed_output])
2054
+ # retrieve_button.click(fn=mix_and_search, inputs=[enhanced_output, result_gallery], outputs=[retrieve_gallery])
2055
+
2056
+ # demo.launch(server_name="0.0.0.0", server_port=12345, share=True)
2057
+
2058
+
brushedit_app_new_aftermeeting_nocirr.py ADDED
@@ -0,0 +1,1809 @@
1
+ ##!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import os, random, sys
4
+ import numpy as np
5
+ import requests
6
+ import torch
7
+
8
+ from pathlib import Path
9
+ import pandas as pd
10
+ import concurrent.futures
11
+ import faiss
12
+ import gradio as gr
13
+
14
+ from PIL import Image
15
+
16
+ import torch.nn.functional as F  # newly added import
17
+ from huggingface_hub import hf_hub_download, snapshot_download
18
+ from scipy.ndimage import binary_dilation, binary_erosion
19
+ from transformers import (LlavaNextProcessor, LlavaNextForConditionalGeneration,
20
+ Qwen2VLForConditionalGeneration, Qwen2VLProcessor)
21
+
22
+ from segment_anything import SamPredictor, build_sam, SamAutomaticMaskGenerator
23
+ from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler
24
+ from diffusers.image_processor import VaeImageProcessor
25
+
26
+
27
+ from app.src.vlm_pipeline import (
28
+ vlm_response_editing_type,
29
+ vlm_response_object_wait_for_edit,
30
+ vlm_response_mask,
31
+ vlm_response_prompt_after_apply_instruction
32
+ )
33
+ from app.src.brushedit_all_in_one_pipeline import BrushEdit_Pipeline
34
+ from app.utils.utils import load_grounding_dino_model
35
+
36
+ from app.src.vlm_template import vlms_template
37
+ from app.src.base_model_template import base_models_template
38
+ from app.src.aspect_ratio_template import aspect_ratios
39
+
40
+ from openai import OpenAI
41
+ base_openai_url = "https://api.deepseek.com/"
42
+ base_api_key = "sk-d145b963a92649a88843caeb741e8bbc"
43
+
44
+
45
+ from transformers import BlipProcessor, BlipForConditionalGeneration
46
+ from transformers import CLIPProcessor, CLIPModel
47
+
48
+ from app.deepseek.instructions import (
49
+ create_apply_editing_messages_deepseek,
50
+ create_decomposed_query_messages_deepseek
51
+ )
52
+ from clip_retrieval.clip_client import ClipClient
53
+
54
+ #### Description ####
55
+ logo = r"""
56
+ <center><img src='./assets/logo_brushedit.png' alt='BrushEdit logo' style="width:80px; margin-bottom:10px"></center>
57
+ """
58
+ head = r"""
59
+ <div style="text-align: center;">
60
+ <h1>Zero-Shot Composed Image Retrieval Based on Diffusion Model Priors and Large Language Models</h1>
61
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
62
+ <a href=''><img src='https://img.shields.io/badge/Project_Page-BrushEdit-green' alt='Project Page'></a>
63
+ <a href=''><img src='https://img.shields.io/badge/Paper-Arxiv-blue'></a>
64
+ <a href=''><img src='https://img.shields.io/badge/Code-Github-orange'></a>
65
+
66
+ </div>
67
+ </br>
68
+ </div>
69
+ """
70
+ descriptions = r"""
71
+ Demo for ZS-CIR"""
72
+
73
+ instructions = r"""
74
+ Demo for ZS-CIR"""
75
+
76
+ tips = r"""
77
+ Demo for ZS-CIR
78
+
79
+ """
80
+
81
+
82
+
83
+ citation = r"""
84
+ Demo for ZS-CIR"""
85
+
86
+ # - - - - - examples - - - - - #
87
+ EXAMPLES = [
88
+
89
+ [
90
+ Image.open("./assets/frog/frog.jpeg").convert("RGBA"),
91
+ "add a magic hat on frog head.",
92
+ 642087011,
93
+ "frog",
94
+ "frog",
95
+ True,
96
+ False,
97
+ "GPT4-o (Highly Recommended)"
98
+ ],
99
+ [
100
+ Image.open("./assets/chinese_girl/chinese_girl.png").convert("RGBA"),
101
+ "replace the background to ancient China.",
102
+ 648464818,
103
+ "chinese_girl",
104
+ "chinese_girl",
105
+ True,
106
+ False,
107
+ "GPT4-o (Highly Recommended)"
108
+ ],
109
+ [
110
+ Image.open("./assets/angel_christmas/angel_christmas.png").convert("RGBA"),
111
+ "remove the deer.",
112
+ 648464818,
113
+ "angel_christmas",
114
+ "angel_christmas",
115
+ False,
116
+ False,
117
+ "GPT4-o (Highly Recommended)"
118
+ ],
119
+ [
120
+ Image.open("./assets/sunflower_girl/sunflower_girl.png").convert("RGBA"),
121
+ "add a wreath on head.",
122
+ 648464818,
123
+ "sunflower_girl",
124
+ "sunflower_girl",
125
+ True,
126
+ False,
127
+ "GPT4-o (Highly Recommended)"
128
+ ],
129
+ [
130
+ Image.open("./assets/girl_on_sun/girl_on_sun.png").convert("RGBA"),
131
+ "add a butterfly fairy.",
132
+ 648464818,
133
+ "girl_on_sun",
134
+ "girl_on_sun",
135
+ True,
136
+ False,
137
+ "GPT4-o (Highly Recommended)"
138
+ ],
139
+ [
140
+ Image.open("./assets/spider_man_rm/spider_man.png").convert("RGBA"),
141
+ "remove the christmas hat.",
142
+ 642087011,
143
+ "spider_man_rm",
144
+ "spider_man_rm",
145
+ False,
146
+ False,
147
+ "GPT4-o (Highly Recommended)"
148
+ ],
149
+ [
150
+ Image.open("./assets/anime_flower/anime_flower.png").convert("RGBA"),
151
+ "remove the flower.",
152
+ 642087011,
153
+ "anime_flower",
154
+ "anime_flower",
155
+ False,
156
+ False,
157
+ "GPT4-o (Highly Recommended)"
158
+ ],
159
+ [
160
+ Image.open("./assets/chenduling/chengduling.jpg").convert("RGBA"),
161
+ "replace the clothes to a delicate floral skirt.",
162
+ 648464818,
163
+ "chenduling",
164
+ "chenduling",
165
+ True,
166
+ False,
167
+ "GPT4-o (Highly Recommended)"
168
+ ],
169
+ [
170
+ Image.open("./assets/hedgehog_rp_bg/hedgehog.png").convert("RGBA"),
171
+ "make the hedgehog in Italy.",
172
+ 648464818,
173
+ "hedgehog_rp_bg",
174
+ "hedgehog_rp_bg",
175
+ True,
176
+ False,
177
+ "GPT4-o (Highly Recommended)"
178
+ ],
179
+
180
+ ]
181
+
182
+ INPUT_IMAGE_PATH = {
183
+ "frog": "./assets/frog/frog.jpeg",
184
+ "chinese_girl": "./assets/chinese_girl/chinese_girl.png",
185
+ "angel_christmas": "./assets/angel_christmas/angel_christmas.png",
186
+ "sunflower_girl": "./assets/sunflower_girl/sunflower_girl.png",
187
+ "girl_on_sun": "./assets/girl_on_sun/girl_on_sun.png",
188
+ "spider_man_rm": "./assets/spider_man_rm/spider_man.png",
189
+ "anime_flower": "./assets/anime_flower/anime_flower.png",
190
+ "chenduling": "./assets/chenduling/chengduling.jpg",
191
+ "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/hedgehog.png",
192
+ }
193
+ MASK_IMAGE_PATH = {
194
+ "frog": "./assets/frog/mask_f7b350de-6f2c-49e3-b535-995c486d78e7.png",
195
+ "chinese_girl": "./assets/chinese_girl/mask_54759648-0989-48e0-bc82-f20e28b5ec29.png",
196
+ "angel_christmas": "./assets/angel_christmas/mask_f15d9b45-c978-4e3d-9f5f-251e308560c3.png",
197
+ "sunflower_girl": "./assets/sunflower_girl/mask_99cc50b4-7dc4-4de5-8748-ec10772f0317.png",
198
+ "girl_on_sun": "./assets/girl_on_sun/mask_264eac8b-8b65-479c-9755-020a60880c37.png",
199
+ "spider_man_rm": "./assets/spider_man_rm/mask_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png",
200
+ "anime_flower": "./assets/anime_flower/mask_37553172-9b38-4727-bf2e-37d7e2b93461.png",
201
+ "chenduling": "./assets/chenduling/mask_68e3ff6f-da07-4b37-91df-13d6eed7b997.png",
202
+ "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/mask_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png",
203
+ }
204
+ MASKED_IMAGE_PATH = {
205
+ "frog": "./assets/frog/masked_image_f7b350de-6f2c-49e3-b535-995c486d78e7.png",
206
+ "chinese_girl": "./assets/chinese_girl/masked_image_54759648-0989-48e0-bc82-f20e28b5ec29.png",
207
+ "angel_christmas": "./assets/angel_christmas/masked_image_f15d9b45-c978-4e3d-9f5f-251e308560c3.png",
208
+ "sunflower_girl": "./assets/sunflower_girl/masked_image_99cc50b4-7dc4-4de5-8748-ec10772f0317.png",
209
+ "girl_on_sun": "./assets/girl_on_sun/masked_image_264eac8b-8b65-479c-9755-020a60880c37.png",
210
+ "spider_man_rm": "./assets/spider_man_rm/masked_image_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png",
211
+ "anime_flower": "./assets/anime_flower/masked_image_37553172-9b38-4727-bf2e-37d7e2b93461.png",
212
+ "chenduling": "./assets/chenduling/masked_image_68e3ff6f-da07-4b37-91df-13d6eed7b997.png",
213
+ "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/masked_image_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png",
214
+ }
215
+ OUTPUT_IMAGE_PATH = {
216
+ "frog": "./assets/frog/image_edit_f7b350de-6f2c-49e3-b535-995c486d78e7_1.png",
217
+ "chinese_girl": "./assets/chinese_girl/image_edit_54759648-0989-48e0-bc82-f20e28b5ec29_1.png",
218
+ "angel_christmas": "./assets/angel_christmas/image_edit_f15d9b45-c978-4e3d-9f5f-251e308560c3_0.png",
219
+ "sunflower_girl": "./assets/sunflower_girl/image_edit_99cc50b4-7dc4-4de5-8748-ec10772f0317_3.png",
220
+ "girl_on_sun": "./assets/girl_on_sun/image_edit_264eac8b-8b65-479c-9755-020a60880c37_0.png",
221
+ "spider_man_rm": "./assets/spider_man_rm/image_edit_a5d410e6-8e8d-432f-8144-defbc3e1eae9_0.png",
222
+ "anime_flower": "./assets/anime_flower/image_edit_37553172-9b38-4727-bf2e-37d7e2b93461_2.png",
223
+ "chenduling": "./assets/chenduling/image_edit_68e3ff6f-da07-4b37-91df-13d6eed7b997_0.png",
224
+ "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/image_edit_db7f8bf8-8349-46d3-b14e-43d67fbe25d3_3.png",
225
+ }
226
+
227
+ # os.environ['GRADIO_TEMP_DIR'] = 'gradio_temp_dir'
228
+ # os.makedirs('gradio_temp_dir', exist_ok=True)
229
+
230
+ VLM_MODEL_NAMES = list(vlms_template.keys())
231
+ DEFAULT_VLM_MODEL_NAME = "Qwen2-VL-7B-Instruct (Default)"
232
+
233
+
234
+ BASE_MODELS = list(base_models_template.keys())
235
+ DEFAULT_BASE_MODEL = "realisticVision (Default)"
236
+
237
+ ASPECT_RATIO_LABELS = list(aspect_ratios)
238
+ DEFAULT_ASPECT_RATIO = ASPECT_RATIO_LABELS[0]
239
+
240
+
241
+ ## init device
242
+ try:
243
+ if torch.cuda.is_available():
244
+ device = "cuda"
245
+ elif sys.platform == "darwin" and torch.backends.mps.is_available():
246
+ device = "mps"
247
+ else:
248
+ device = "cpu"
249
+ except Exception:
250
+ device = "cpu"
251
+
252
+ # ## init torch dtype
253
+ # if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
254
+ # torch_dtype = torch.bfloat16
255
+ # else:
256
+ # torch_dtype = torch.float16
257
+
258
+ # if device == "mps":
259
+ # torch_dtype = torch.float16
260
+
261
+ torch_dtype = torch.float16
262
+
263
+
264
+
265
+ # download hf models
266
+ BrushEdit_path = "models/"
267
+ if not os.path.exists(BrushEdit_path):
268
+ BrushEdit_path = snapshot_download(
269
+ repo_id="TencentARC/BrushEdit",
270
+ local_dir=BrushEdit_path,
271
+ token=os.getenv("HF_TOKEN"),
272
+ )
273
+
274
+ ## init default VLM
275
+ vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[DEFAULT_VLM_MODEL_NAME]
276
+ if vlm_processor != "" and vlm_model != "":
277
+ vlm_model.to(device)
278
+ else:
279
+ raise gr.Error("Please download the default VLM model " + DEFAULT_VLM_MODEL_NAME + " first.")
280
+
281
+ ## init default LLM
282
+ llm_model = OpenAI(api_key=base_api_key, base_url=base_openai_url)
283
+
284
+ ## init base model
285
+ base_model_path = os.path.join(BrushEdit_path, "base_model/realisticVisionV60B1_v51VAE")
286
+ brushnet_path = os.path.join(BrushEdit_path, "brushnetX")
287
+ sam_path = os.path.join(BrushEdit_path, "sam/sam_vit_h_4b8939.pth")
288
+ groundingdino_path = os.path.join(BrushEdit_path, "grounding_dino/groundingdino_swint_ogc.pth")
289
+
290
+
291
+ # input brushnetX ckpt path
292
+ brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch_dtype)
293
+ pipe = StableDiffusionBrushNetPipeline.from_pretrained(
294
+ base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
295
+ )
296
+ # speed up diffusion process with faster scheduler and memory optimization
297
+ pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
298
+ # remove following line if xformers is not installed or when using Torch 2.0.
299
+ # pipe.enable_xformers_memory_efficient_attention()
300
+ pipe.enable_model_cpu_offload()
301
+
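+ # Note: enable_model_cpu_offload() (diffusers + accelerate) keeps submodules on the CPU
+ # and moves each one to the GPU only while it is used, trading speed for lower VRAM.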
302
+
303
+ ## init SAM
304
+ sam = build_sam(checkpoint=sam_path)
305
+ sam.to(device=device)
306
+ sam_predictor = SamPredictor(sam)
307
+ sam_automask_generator = SamAutomaticMaskGenerator(sam)
308
+
309
+ ## init groundingdino_model
310
+ config_file = 'app/utils/GroundingDINO_SwinT_OGC.py'
311
+ groundingdino_model = load_grounding_dino_model(config_file, groundingdino_path, device=device)
312
+
313
+ ## Ordinary function
314
+ def crop_and_resize(image: Image.Image,
315
+ target_width: int,
316
+ target_height: int) -> Image.Image:
317
+ """
318
+ Crops and resizes an image while preserving the aspect ratio.
319
+
320
+ Args:
321
+ image (Image.Image): Input PIL image to be cropped and resized.
322
+ target_width (int): Target width of the output image.
323
+ target_height (int): Target height of the output image.
324
+
325
+ Returns:
326
+ Image.Image: Cropped and resized image.
327
+ """
328
+ # Original dimensions
329
+ original_width, original_height = image.size
330
+ original_aspect = original_width / original_height
331
+ target_aspect = target_width / target_height
332
+
333
+ # Calculate crop box to maintain aspect ratio
334
+ if original_aspect > target_aspect:
335
+ # Crop horizontally
336
+ new_width = int(original_height * target_aspect)
337
+ new_height = original_height
338
+ left = (original_width - new_width) / 2
339
+ top = 0
340
+ right = left + new_width
341
+ bottom = original_height
342
+ else:
343
+ # Crop vertically
344
+ new_width = original_width
345
+ new_height = int(original_width / target_aspect)
346
+ left = 0
347
+ top = (original_height - new_height) / 2
348
+ right = original_width
349
+ bottom = top + new_height
350
+
351
+ # Crop and resize
352
+ cropped_image = image.crop((left, top, right, bottom))
353
+ resized_image = cropped_image.resize((target_width, target_height), Image.NEAREST)
354
+ return resized_image
355
+
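+ # Illustrative usage (hypothetical file path): center-crop a photo to the 640x640
+ # "Small Square (1:1)" preset before any further processing.
+ # square = crop_and_resize(Image.open("photo.jpg"), target_width=640, target_height=640)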
356
+
357
+ ## Ordinary function
358
+ def resize(image: Image.Image,
359
+ target_width: int,
360
+ target_height: int) -> Image.Image:
361
+ """
362
+ Resizes an image to the target width and height without preserving the aspect ratio.
363
+
364
+ Args:
365
+ image (Image.Image): Input PIL image to be resized.
366
+ target_width (int): Target width of the output image.
367
+ target_height (int): Target height of the output image.
368
+
369
+ Returns:
370
+ Image.Image: Resized image.
371
+ """
372
+ # Direct resize to the target size; the aspect ratio is not preserved
373
+ resized_image = image.resize((target_width, target_height), Image.NEAREST)
374
+ return resized_image
375
+
376
+
377
+ def move_mask_func(mask, direction, units):
378
+ binary_mask = mask.squeeze()>0
379
+ rows, cols = binary_mask.shape
380
+ moved_mask = np.zeros_like(binary_mask, dtype=bool)
381
+
382
+ if direction == 'down':
383
+ # move down
384
+ moved_mask[max(0, units):, :] = binary_mask[:rows - units, :]
385
+
386
+ elif direction == 'up':
387
+ # move up
388
+ moved_mask[:rows - units, :] = binary_mask[units:, :]
389
+
390
+ elif direction == 'right':
391
+ # move right
392
+ moved_mask[:, max(0, units):] = binary_mask[:, :cols - units]
393
+
394
+ elif direction == 'left':
395
+ # move left
396
+ moved_mask[:, :cols - units] = binary_mask[:, units:]
397
+
398
+ return moved_mask
399
+
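+ # Note: move_mask_func shifts the binary mask by `units` pixels in the given direction;
+ # pixels pushed past the border are discarded and the vacated area becomes False,
+ # so repeated moves are not reversible.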
400
+
401
+ def random_mask_func(mask, dilation_type='square_dilation', dilation_size=20):
402
+ # Binarize the mask, then grow or shrink it according to dilation_type
403
+ binary_mask = mask.squeeze()>0
404
+
405
+ if dilation_type == 'square_dilation':
406
+ structure = np.ones((dilation_size, dilation_size), dtype=bool)
407
+ dilated_mask = binary_dilation(binary_mask, structure=structure)
408
+ elif dilation_type == 'square_erosion':
409
+ structure = np.ones((dilation_size, dilation_size), dtype=bool)
410
+ dilated_mask = binary_erosion(binary_mask, structure=structure)
411
+ elif dilation_type == 'bounding_box':
412
+ # find the bounding box of the mask's foreground pixels
413
+ rows, cols = np.where(binary_mask)
414
+ if len(rows) == 0 or len(cols) == 0:
415
+ return mask # return original mask if no valid points
416
+
417
+ min_row = np.min(rows)
418
+ max_row = np.max(rows)
419
+ min_col = np.min(cols)
420
+ max_col = np.max(cols)
421
+
422
+ # create a bounding box
423
+ dilated_mask = np.zeros_like(binary_mask, dtype=bool)
424
+ dilated_mask[min_row:max_row + 1, min_col:max_col + 1] = True
425
+
426
+ elif dilation_type == 'bounding_ellipse':
427
+ # find the bounding box of the mask's foreground pixels
428
+ rows, cols = np.where(binary_mask)
429
+ if len(rows) == 0 or len(cols) == 0:
430
+ return mask # return original mask if no valid points
431
+
432
+ min_row = np.min(rows)
433
+ max_row = np.max(rows)
434
+ min_col = np.min(cols)
435
+ max_col = np.max(cols)
436
+
437
+ # calculate the center and axis length of the ellipse
438
+ center = ((min_col + max_col) // 2, (min_row + max_row) // 2)
439
+ a = (max_col - min_col) // 2 # half long axis
440
+ b = (max_row - min_row) // 2 # half short axis
441
+
442
+ # create a bounding ellipse
443
+ y, x = np.ogrid[:mask.shape[0], :mask.shape[1]]
444
+ ellipse_mask = ((x - center[0])**2 / a**2 + (y - center[1])**2 / b**2) <= 1
445
+ dilated_mask = np.zeros_like(binary_mask, dtype=bool)
446
+ dilated_mask[ellipse_mask] = True
447
+ else:
448
+ raise ValueError("dilation_type must be one of 'square_dilation', 'square_erosion', 'bounding_box', or 'bounding_ellipse'")
449
+
450
+ # convert the boolean mask to a uint8 (h, w, 1) mask with values in {0, 255}
451
+ dilated_mask = np.uint8(dilated_mask[:,:,np.newaxis]) * 255
452
+ return dilated_mask
453
+
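+ # random_mask_func returns an (h, w, 1) uint8 mask with values in {0, 255}. Supported
+ # dilation_type values: 'square_dilation' (grow the region), 'square_erosion' (shrink it),
+ # 'bounding_box' and 'bounding_ellipse' (replace the region with its enclosing shape).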
454
+
455
+ ## Gradio component function
456
+ def update_vlm_model(vlm_name):
457
+ global vlm_model, vlm_processor
458
+ if vlm_model is not None:
459
+ del vlm_model
460
+ torch.cuda.empty_cache()
461
+
462
+ vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[vlm_name]
463
+
464
+ ## We recommend preloading models; otherwise the first use will trigger a lengthy download. The available models can be edited in vlm_template.py.
465
+ if vlm_type == "llava-next":
466
+ if vlm_processor != "" and vlm_model != "":
467
+ vlm_model.to(device)
468
+ return vlm_model_dropdown
469
+ else:
470
+ if os.path.exists(vlm_local_path):
471
+ vlm_processor = LlavaNextProcessor.from_pretrained(vlm_local_path)
472
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto")
473
+ else:
474
+ if vlm_name == "llava-v1.6-mistral-7b-hf (Preload)":
475
+ vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
476
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", torch_dtype="auto", device_map="auto")
477
+ elif vlm_name == "llama3-llava-next-8b-hf (Preload)":
478
+ vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llama3-llava-next-8b-hf")
479
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llama3-llava-next-8b-hf", torch_dtype="auto", device_map="auto")
480
+ elif vlm_name == "llava-v1.6-vicuna-13b-hf (Preload)":
481
+ vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf")
482
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf", torch_dtype="auto", device_map="auto")
483
+ elif vlm_name == "llava-v1.6-34b-hf (Preload)":
484
+ vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-34b-hf")
485
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-34b-hf", torch_dtype="auto", device_map="auto")
486
+ elif vlm_name == "llava-next-72b-hf (Preload)":
487
+ vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-next-72b-hf")
488
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-next-72b-hf", torch_dtype="auto", device_map="auto")
489
+ elif vlm_type == "qwen2-vl":
490
+ if vlm_processor != "" and vlm_model != "":
491
+ vlm_model.to(device)
492
+ return vlm_model_dropdown
493
+ else:
494
+ if os.path.exists(vlm_local_path):
495
+ vlm_processor = Qwen2VLProcessor.from_pretrained(vlm_local_path)
496
+ vlm_model = Qwen2VLForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto")
497
+ else:
498
+ if vlm_name == "qwen2-vl-2b-instruct (Preload)":
499
+ vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
500
+ vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", torch_dtype="auto", device_map="auto")
501
+ elif vlm_name == "qwen2-vl-7b-instruct (Preload)":
502
+ vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
503
+ vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto")
504
+ elif vlm_name == "qwen2-vl-72b-instruct (Preload)":
505
+ vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-72B-Instruct")
506
+ vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-72B-Instruct", torch_dtype="auto", device_map="auto")
507
+ elif vlm_type == "openai":
508
+ pass
509
+ return "success"
510
+
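+ # Note: when neither a preloaded instance nor a local checkpoint is available, the
+ # branches above download the selected VLM from the Hugging Face Hub, which can take
+ # a long time on first use.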
511
+
512
+ def update_base_model(base_model_name):
513
+ global pipe
514
+ ## We recommend preloading models; otherwise the first use will trigger a lengthy download. The available models can be edited in base_model_template.py.
515
+ if pipe is not None:
516
+ del pipe
517
+ torch.cuda.empty_cache()
518
+ base_model_path, pipe = base_models_template[base_model_name]
519
+ if pipe != "":
520
+ pipe.to(device)
521
+ else:
522
+ if os.path.exists(base_model_path):
523
+ pipe = StableDiffusionBrushNetPipeline.from_pretrained(
524
+ base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
525
+ )
526
+ # pipe.enable_xformers_memory_efficient_attention()
527
+ pipe.enable_model_cpu_offload()
528
+ else:
529
+ raise gr.Error(f"The base model {base_model_name} does not exist")
530
+ return "success"
531
+
532
+
533
+ def process_random_mask(input_image,
534
+ original_image,
535
+ original_mask,
536
+ resize_default,
537
+ aspect_ratio_name,
538
+ ):
539
+
540
+ alpha_mask = input_image["layers"][0].split()[3]
541
+ input_mask = np.asarray(alpha_mask)
542
+
543
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
544
+ if output_w == "" or output_h == "":
545
+ output_h, output_w = original_image.shape[:2]
546
+ if resize_default:
547
+ short_side = min(output_w, output_h)
548
+ scale_ratio = 640 / short_side
549
+ output_w = int(output_w * scale_ratio)
550
+ output_h = int(output_h * scale_ratio)
551
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
552
+ original_image = np.array(original_image)
553
+ if input_mask is not None:
554
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
555
+ input_mask = np.array(input_mask)
556
+ if original_mask is not None:
557
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
558
+ original_mask = np.array(original_mask)
559
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
560
+ else:
561
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
562
+ pass
563
+ else:
564
+ if resize_default:
565
+ short_side = min(output_w, output_h)
566
+ scale_ratio = 640 / short_side
567
+ output_w = int(output_w * scale_ratio)
568
+ output_h = int(output_h * scale_ratio)
569
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
570
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
571
+ original_image = np.array(original_image)
572
+ if input_mask is not None:
573
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
574
+ input_mask = np.array(input_mask)
575
+ if original_mask is not None:
576
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
577
+ original_mask = np.array(original_mask)
578
+
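+ # The resizing block above is shared by all mask-editing callbacks: with "Custom
+ # resolution" the original image size is kept, and when resize_default is set the
+ # short side is scaled to 640 pixels before the image and both masks are resampled.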
579
+
580
+ if input_mask.max() == 0:
581
+ original_mask = original_mask
582
+ else:
583
+ original_mask = input_mask
584
+
585
+ if original_mask is None:
586
+ raise gr.Error('Please generate mask first')
587
+
588
+ if original_mask.ndim == 2:
589
+ original_mask = original_mask[:,:,None]
590
+
591
+ dilation_type = np.random.choice(['bounding_box', 'bounding_ellipse'])
592
+ random_mask = random_mask_func(original_mask, dilation_type).squeeze()
593
+
594
+ mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
595
+
596
+ masked_image = original_image * (1 - (random_mask[:,:,None]>0))
597
+ masked_image = masked_image.astype(original_image.dtype)
598
+ masked_image = Image.fromarray(masked_image)
599
+
600
+
601
+ return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
602
+
603
+
604
+ def process_dilation_mask(input_image,
605
+ original_image,
606
+ original_mask,
607
+ resize_default,
608
+ aspect_ratio_name,
609
+ dilation_size=20):
610
+
611
+ alpha_mask = input_image["layers"][0].split()[3]
612
+ input_mask = np.asarray(alpha_mask)
613
+
614
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
615
+ if output_w == "" or output_h == "":
616
+ output_h, output_w = original_image.shape[:2]
617
+ if resize_default:
618
+ short_side = min(output_w, output_h)
619
+ scale_ratio = 640 / short_side
620
+ output_w = int(output_w * scale_ratio)
621
+ output_h = int(output_h * scale_ratio)
622
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
623
+ original_image = np.array(original_image)
624
+ if input_mask is not None:
625
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
626
+ input_mask = np.array(input_mask)
627
+ if original_mask is not None:
628
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
629
+ original_mask = np.array(original_mask)
630
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
631
+ else:
632
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
633
+ pass
634
+ else:
635
+ if resize_default:
636
+ short_side = min(output_w, output_h)
637
+ scale_ratio = 640 / short_side
638
+ output_w = int(output_w * scale_ratio)
639
+ output_h = int(output_h * scale_ratio)
640
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
641
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
642
+ original_image = np.array(original_image)
643
+ if input_mask is not None:
644
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
645
+ input_mask = np.array(input_mask)
646
+ if original_mask is not None:
647
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
648
+ original_mask = np.array(original_mask)
649
+
650
+ if input_mask.max() == 0:
651
+ original_mask = original_mask
652
+ else:
653
+ original_mask = input_mask
654
+
655
+ if original_mask is None:
656
+ raise gr.Error('Please generate mask first')
657
+
658
+ if original_mask.ndim == 2:
659
+ original_mask = original_mask[:,:,None]
660
+
661
+ dilation_type = np.random.choice(['square_dilation'])
662
+ random_mask = random_mask_func(original_mask, dilation_type, dilation_size).squeeze()
663
+
664
+ mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
665
+
666
+ masked_image = original_image * (1 - (random_mask[:,:,None]>0))
667
+ masked_image = masked_image.astype(original_image.dtype)
668
+ masked_image = Image.fromarray(masked_image)
669
+
670
+ return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
671
+
672
+
673
+ def process_erosion_mask(input_image,
674
+ original_image,
675
+ original_mask,
676
+ resize_default,
677
+ aspect_ratio_name,
678
+ dilation_size=20):
679
+ alpha_mask = input_image["layers"][0].split()[3]
680
+ input_mask = np.asarray(alpha_mask)
681
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
682
+ if output_w == "" or output_h == "":
683
+ output_h, output_w = original_image.shape[:2]
684
+ if resize_default:
685
+ short_side = min(output_w, output_h)
686
+ scale_ratio = 640 / short_side
687
+ output_w = int(output_w * scale_ratio)
688
+ output_h = int(output_h * scale_ratio)
689
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
690
+ original_image = np.array(original_image)
691
+ if input_mask is not None:
692
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
693
+ input_mask = np.array(input_mask)
694
+ if original_mask is not None:
695
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
696
+ original_mask = np.array(original_mask)
697
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
698
+ else:
699
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
700
+ pass
701
+ else:
702
+ if resize_default:
703
+ short_side = min(output_w, output_h)
704
+ scale_ratio = 640 / short_side
705
+ output_w = int(output_w * scale_ratio)
706
+ output_h = int(output_h * scale_ratio)
707
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
708
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
709
+ original_image = np.array(original_image)
710
+ if input_mask is not None:
711
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
712
+ input_mask = np.array(input_mask)
713
+ if original_mask is not None:
714
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
715
+ original_mask = np.array(original_mask)
716
+
717
+ if input_mask.max() == 0:
718
+ original_mask = original_mask
719
+ else:
720
+ original_mask = input_mask
721
+
722
+ if original_mask is None:
723
+ raise gr.Error('Please generate mask first')
724
+
725
+ if original_mask.ndim == 2:
726
+ original_mask = original_mask[:,:,None]
727
+
728
+ dilation_type = np.random.choice(['square_erosion'])
729
+ random_mask = random_mask_func(original_mask, dilation_type, dilation_size).squeeze()
730
+
731
+ mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
732
+
733
+ masked_image = original_image * (1 - (random_mask[:,:,None]>0))
734
+ masked_image = masked_image.astype(original_image.dtype)
735
+ masked_image = Image.fromarray(masked_image)
736
+
737
+
738
+ return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
739
+
740
+
741
+ def move_mask_left(input_image,
742
+ original_image,
743
+ original_mask,
744
+ moving_pixels,
745
+ resize_default,
746
+ aspect_ratio_name):
747
+
748
+ alpha_mask = input_image["layers"][0].split()[3]
749
+ input_mask = np.asarray(alpha_mask)
750
+
751
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
752
+ if output_w == "" or output_h == "":
753
+ output_h, output_w = original_image.shape[:2]
754
+ if resize_default:
755
+ short_side = min(output_w, output_h)
756
+ scale_ratio = 640 / short_side
757
+ output_w = int(output_w * scale_ratio)
758
+ output_h = int(output_h * scale_ratio)
759
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
760
+ original_image = np.array(original_image)
761
+ if input_mask is not None:
762
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
763
+ input_mask = np.array(input_mask)
764
+ if original_mask is not None:
765
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
766
+ original_mask = np.array(original_mask)
767
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
768
+ else:
769
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
770
+ pass
771
+ else:
772
+ if resize_default:
773
+ short_side = min(output_w, output_h)
774
+ scale_ratio = 640 / short_side
775
+ output_w = int(output_w * scale_ratio)
776
+ output_h = int(output_h * scale_ratio)
777
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
778
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
779
+ original_image = np.array(original_image)
780
+ if input_mask is not None:
781
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
782
+ input_mask = np.array(input_mask)
783
+ if original_mask is not None:
784
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
785
+ original_mask = np.array(original_mask)
786
+
787
+ if input_mask.max() == 0:
788
+ original_mask = original_mask
789
+ else:
790
+ original_mask = input_mask
791
+
792
+ if original_mask is None:
793
+ raise gr.Error('Please generate mask first')
794
+
795
+ if original_mask.ndim == 2:
796
+ original_mask = original_mask[:,:,None]
797
+
798
+ moved_mask = move_mask_func(original_mask, 'left', int(moving_pixels)).squeeze()
799
+ mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
800
+
801
+ masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
802
+ masked_image = masked_image.astype(original_image.dtype)
803
+ masked_image = Image.fromarray(masked_image)
804
+
805
+ if moved_mask.max() <= 1:
806
+ moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
807
+ original_mask = moved_mask
808
+ return [masked_image], [mask_image], original_mask.astype(np.uint8)
809
+
810
+
811
+ def move_mask_right(input_image,
812
+ original_image,
813
+ original_mask,
814
+ moving_pixels,
815
+ resize_default,
816
+ aspect_ratio_name):
817
+ alpha_mask = input_image["layers"][0].split()[3]
818
+ input_mask = np.asarray(alpha_mask)
819
+
820
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
821
+ if output_w == "" or output_h == "":
822
+ output_h, output_w = original_image.shape[:2]
823
+ if resize_default:
824
+ short_side = min(output_w, output_h)
825
+ scale_ratio = 640 / short_side
826
+ output_w = int(output_w * scale_ratio)
827
+ output_h = int(output_h * scale_ratio)
828
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
829
+ original_image = np.array(original_image)
830
+ if input_mask is not None:
831
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
832
+ input_mask = np.array(input_mask)
833
+ if original_mask is not None:
834
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
835
+ original_mask = np.array(original_mask)
836
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
837
+ else:
838
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
839
+ pass
840
+ else:
841
+ if resize_default:
842
+ short_side = min(output_w, output_h)
843
+ scale_ratio = 640 / short_side
844
+ output_w = int(output_w * scale_ratio)
845
+ output_h = int(output_h * scale_ratio)
846
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
847
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
848
+ original_image = np.array(original_image)
849
+ if input_mask is not None:
850
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
851
+ input_mask = np.array(input_mask)
852
+ if original_mask is not None:
853
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
854
+ original_mask = np.array(original_mask)
855
+
856
+ if input_mask.max() == 0:
857
+ original_mask = original_mask
858
+ else:
859
+ original_mask = input_mask
860
+
861
+ if original_mask is None:
862
+ raise gr.Error('Please generate mask first')
863
+
864
+ if original_mask.ndim == 2:
865
+ original_mask = original_mask[:,:,None]
866
+
867
+ moved_mask = move_mask_func(original_mask, 'right', int(moving_pixels)).squeeze()
868
+
869
+ mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
870
+
871
+ masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
872
+ masked_image = masked_image.astype(original_image.dtype)
873
+ masked_image = Image.fromarray(masked_image)
874
+
875
+
876
+ if moved_mask.max() <= 1:
877
+ moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
878
+ original_mask = moved_mask
879
+
880
+ return [masked_image], [mask_image], original_mask.astype(np.uint8)
881
+
882
+
883
+ def move_mask_up(input_image,
884
+ original_image,
885
+ original_mask,
886
+ moving_pixels,
887
+ resize_default,
888
+ aspect_ratio_name):
889
+ alpha_mask = input_image["layers"][0].split()[3]
890
+ input_mask = np.asarray(alpha_mask)
891
+
892
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
893
+ if output_w == "" or output_h == "":
894
+ output_h, output_w = original_image.shape[:2]
895
+ if resize_default:
896
+ short_side = min(output_w, output_h)
897
+ scale_ratio = 640 / short_side
898
+ output_w = int(output_w * scale_ratio)
899
+ output_h = int(output_h * scale_ratio)
900
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
901
+ original_image = np.array(original_image)
902
+ if input_mask is not None:
903
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
904
+ input_mask = np.array(input_mask)
905
+ if original_mask is not None:
906
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
907
+ original_mask = np.array(original_mask)
908
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
909
+ else:
910
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
911
+ pass
912
+ else:
913
+ if resize_default:
914
+ short_side = min(output_w, output_h)
915
+ scale_ratio = 640 / short_side
916
+ output_w = int(output_w * scale_ratio)
917
+ output_h = int(output_h * scale_ratio)
918
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
919
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
920
+ original_image = np.array(original_image)
921
+ if input_mask is not None:
922
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
923
+ input_mask = np.array(input_mask)
924
+ if original_mask is not None:
925
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
926
+ original_mask = np.array(original_mask)
927
+
928
+ if input_mask.max() == 0:
929
+ original_mask = original_mask
930
+ else:
931
+ original_mask = input_mask
932
+
933
+ if original_mask is None:
934
+ raise gr.Error('Please generate mask first')
935
+
936
+ if original_mask.ndim == 2:
937
+ original_mask = original_mask[:,:,None]
938
+
939
+ moved_mask = move_mask_func(original_mask, 'up', int(moving_pixels)).squeeze()
940
+ mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
941
+
942
+ masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
943
+ masked_image = masked_image.astype(original_image.dtype)
944
+ masked_image = Image.fromarray(masked_image)
945
+
946
+ if moved_mask.max() <= 1:
947
+ moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
948
+ original_mask = moved_mask
949
+
950
+ return [masked_image], [mask_image], original_mask.astype(np.uint8)
951
+
952
+
953
+ def move_mask_down(input_image,
954
+ original_image,
955
+ original_mask,
956
+ moving_pixels,
957
+ resize_default,
958
+ aspect_ratio_name):
959
+ alpha_mask = input_image["layers"][0].split()[3]
960
+ input_mask = np.asarray(alpha_mask)
961
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
962
+ if output_w == "" or output_h == "":
963
+ output_h, output_w = original_image.shape[:2]
964
+ if resize_default:
965
+ short_side = min(output_w, output_h)
966
+ scale_ratio = 640 / short_side
967
+ output_w = int(output_w * scale_ratio)
968
+ output_h = int(output_h * scale_ratio)
969
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
970
+ original_image = np.array(original_image)
971
+ if input_mask is not None:
972
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
973
+ input_mask = np.array(input_mask)
974
+ if original_mask is not None:
975
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
976
+ original_mask = np.array(original_mask)
977
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
978
+ else:
979
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
980
+ pass
981
+ else:
982
+ if resize_default:
983
+ short_side = min(output_w, output_h)
984
+ scale_ratio = 640 / short_side
985
+ output_w = int(output_w * scale_ratio)
986
+ output_h = int(output_h * scale_ratio)
987
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
988
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
989
+ original_image = np.array(original_image)
990
+ if input_mask is not None:
991
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
992
+ input_mask = np.array(input_mask)
993
+ if original_mask is not None:
994
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
995
+ original_mask = np.array(original_mask)
996
+
997
+ if input_mask.max() == 0:
998
+ original_mask = original_mask
999
+ else:
1000
+ original_mask = input_mask
1001
+
1002
+ if original_mask is None:
1003
+ raise gr.Error('Please generate mask first')
1004
+
1005
+ if original_mask.ndim == 2:
1006
+ original_mask = original_mask[:,:,None]
1007
+
1008
+ moved_mask = move_mask_func(original_mask, 'down', int(moving_pixels)).squeeze()
1009
+ mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
1010
+
1011
+ masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
1012
+ masked_image = masked_image.astype(original_image.dtype)
1013
+ masked_image = Image.fromarray(masked_image)
1014
+
1015
+ if moved_mask.max() <= 1:
1016
+ moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
1017
+ original_mask = moved_mask
1018
+
1019
+ return [masked_image], [mask_image], original_mask.astype(np.uint8)
1020
+
1021
+
1022
+ def invert_mask(input_image,
1023
+ original_image,
1024
+ original_mask,
1025
+ ):
1026
+ alpha_mask = input_image["layers"][0].split()[3]
1027
+ input_mask = np.asarray(alpha_mask)
1028
+ if input_mask.max() == 0:
1029
+ original_mask = 1 - (original_mask>0).astype(np.uint8)
1030
+ else:
1031
+ original_mask = 1 - (input_mask>0).astype(np.uint8)
1032
+
1033
+ if original_mask is None:
1034
+ raise gr.Error('Please generate mask first')
1035
+
1036
+ original_mask = original_mask.squeeze()
1037
+ mask_image = Image.fromarray(original_mask*255).convert("RGB")
1038
+
1039
+ if original_mask.ndim == 2:
1040
+ original_mask = original_mask[:,:,None]
1041
+
1042
+ if original_mask.max() <= 1:
1043
+ original_mask = (original_mask * 255).astype(np.uint8)
1044
+
1045
+ masked_image = original_image * (1 - (original_mask>0))
1046
+ masked_image = masked_image.astype(original_image.dtype)
1047
+ masked_image = Image.fromarray(masked_image)
1048
+
1049
+ return [masked_image], [mask_image], original_mask, True
1050
+
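+ # Note: the trailing True sets invert_mask_state, so process() keeps the inverted
+ # mask instead of overwriting it with a newly drawn layer.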
1051
+
1052
+ def init_img(base,
1053
+ init_type,
1054
+ prompt,
1055
+ aspect_ratio,
1056
+ example_change_times
1057
+ ):
1058
+ image_pil = base["background"].convert("RGB")
1059
+ original_image = np.array(image_pil)
1060
+ if max(original_image.shape[0], original_image.shape[1]) * 1.0 / min(original_image.shape[0], original_image.shape[1])>2.0:
1061
+ raise gr.Error('image aspect ratio cannot be larger than 2.0')
1062
+ if init_type in MASK_IMAGE_PATH.keys() and example_change_times < 2:
1063
+ mask_gallery = [Image.open(MASK_IMAGE_PATH[init_type]).convert("L")]
1064
+ masked_gallery = [Image.open(MASKED_IMAGE_PATH[init_type]).convert("RGB")]
1065
+ result_gallery = [Image.open(OUTPUT_IMAGE_PATH[init_type]).convert("RGB")]
1066
+ width, height = image_pil.size
1067
+ image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
1068
+ height_new, width_new = image_processor.get_default_height_width(image_pil, height, width)
1069
+ image_pil = image_pil.resize((width_new, height_new))
1070
+ mask_gallery[0] = mask_gallery[0].resize((width_new, height_new))
1071
+ masked_gallery[0] = masked_gallery[0].resize((width_new, height_new))
1072
+ result_gallery[0] = result_gallery[0].resize((width_new, height_new))
1073
+ original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1
1074
+ return base, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, "", "", "Custom resolution", False, False, example_change_times
1075
+ else:
1076
+ if aspect_ratio not in ASPECT_RATIO_LABELS:
1077
+ aspect_ratio = "Custom resolution"
1078
+ return base, original_image, None, "", None, None, None, "", "", aspect_ratio, True, False, 0
1079
+
1080
+
1081
+ def reset_func(input_image,
1082
+ original_image,
1083
+ original_mask,
1084
+ prompt,
1085
+ target_prompt,
1086
+ ):
1087
+ input_image = None
1088
+ original_image = None
1089
+ original_mask = None
1090
+ prompt = ''
1091
+ mask_gallery = []
1092
+ masked_gallery = []
1093
+ result_gallery = []
1094
+ target_prompt = ''
1095
+ if torch.cuda.is_available():
1096
+ torch.cuda.empty_cache()
1097
+ return input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, True, False
1098
+
1099
+
1100
+ def update_example(example_type,
1101
+ prompt,
1102
+ example_change_times):
1103
+ input_image = INPUT_IMAGE_PATH[example_type]
1104
+ image_pil = Image.open(input_image).convert("RGB")
1105
+ mask_gallery = [Image.open(MASK_IMAGE_PATH[example_type]).convert("L")]
1106
+ masked_gallery = [Image.open(MASKED_IMAGE_PATH[example_type]).convert("RGB")]
1107
+ result_gallery = [Image.open(OUTPUT_IMAGE_PATH[example_type]).convert("RGB")]
1108
+ width, height = image_pil.size
1109
+ image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
1110
+ height_new, width_new = image_processor.get_default_height_width(image_pil, height, width)
1111
+ image_pil = image_pil.resize((width_new, height_new))
1112
+ mask_gallery[0] = mask_gallery[0].resize((width_new, height_new))
1113
+ masked_gallery[0] = masked_gallery[0].resize((width_new, height_new))
1114
+ result_gallery[0] = result_gallery[0].resize((width_new, height_new))
1115
+
1116
+ original_image = np.array(image_pil)
1117
+ original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1
1118
+ aspect_ratio = "Custom resolution"
1119
+ example_change_times += 1
1120
+ return input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, "", False, example_change_times
1121
+
1122
+
1123
+ def generate_target_prompt(input_image,
1124
+ original_image,
1125
+ prompt):
1126
+ # load example image
1127
+ if isinstance(original_image, str):
1128
+ original_image = input_image
1129
+
1130
+ prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(
1131
+ vlm_processor,
1132
+ vlm_model,
1133
+ original_image,
1134
+ prompt,
1135
+ device)
1136
+ return prompt_after_apply_instruction
1137
+
1138
+
1139
+
1140
+
1141
+ def process_mask(input_image,
1142
+ original_image,
1143
+ prompt,
1144
+ resize_default,
1145
+ aspect_ratio_name):
1146
+ if original_image is None:
1147
+ raise gr.Error('Please upload the input image')
1148
+ if prompt is None:
1149
+ raise gr.Error("Please input your instructions, e.g., remove the xxx")
1150
+
1151
+ ## load mask
1152
+ alpha_mask = input_image["layers"][0].split()[3]
1153
+ input_mask = np.array(alpha_mask)
1154
+
1155
+ # load example image
1156
+ if isinstance(original_image, str):
1157
+ original_image = input_image["background"]
1158
+
1159
+ if input_mask.max() == 0:
1160
+ category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device)
1161
+
1162
+ object_wait_for_edit = vlm_response_object_wait_for_edit(vlm_processor,
1163
+ vlm_model,
1164
+ original_image,
1165
+ category,
1166
+ prompt,
1167
+ device)
1168
+ # original mask: h,w,1 [0, 255]
1169
+ original_mask = vlm_response_mask(
1170
+ vlm_processor,
1171
+ vlm_model,
1172
+ category,
1173
+ original_image,
1174
+ prompt,
1175
+ object_wait_for_edit,
1176
+ sam,
1177
+ sam_predictor,
1178
+ sam_automask_generator,
1179
+ groundingdino_model,
1180
+ device).astype(np.uint8)
1181
+ else:
1182
+ original_mask = input_mask.astype(np.uint8)
1183
+ category = None
1184
+
1185
+ ## resize mask if needed
1186
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
1187
+ if output_w == "" or output_h == "":
1188
+ output_h, output_w = original_image.shape[:2]
1189
+ if resize_default:
1190
+ short_side = min(output_w, output_h)
1191
+ scale_ratio = 640 / short_side
1192
+ output_w = int(output_w * scale_ratio)
1193
+ output_h = int(output_h * scale_ratio)
1194
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1195
+ original_image = np.array(original_image)
1196
+ if input_mask is not None:
1197
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1198
+ input_mask = np.array(input_mask)
1199
+ if original_mask is not None:
1200
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1201
+ original_mask = np.array(original_mask)
1202
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1203
+ else:
1204
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1205
+ pass
1206
+ else:
1207
+ if resize_default:
1208
+ short_side = min(output_w, output_h)
1209
+ scale_ratio = 640 / short_side
1210
+ output_w = int(output_w * scale_ratio)
1211
+ output_h = int(output_h * scale_ratio)
1212
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1213
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1214
+ original_image = np.array(original_image)
1215
+ if input_mask is not None:
1216
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1217
+ input_mask = np.array(input_mask)
1218
+ if original_mask is not None:
1219
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1220
+ original_mask = np.array(original_mask)
1221
+
1222
+
1223
+ if original_mask.ndim == 2:
1224
+ original_mask = original_mask[:,:,None]
1225
+
1226
+ mask_image = Image.fromarray(original_mask.squeeze().astype(np.uint8)).convert("RGB")
1227
+
1228
+ masked_image = original_image * (1 - (original_mask>0))
1229
+ masked_image = masked_image.astype(np.uint8)
1230
+ masked_image = Image.fromarray(masked_image)
1231
+
1232
+ return [masked_image], [mask_image], original_mask.astype(np.uint8), category
1233
+
1234
+
1235
+
1236
+ def process(input_image,
1237
+ original_image,
1238
+ original_mask,
1239
+ prompt,
1240
+ negative_prompt,
1241
+ control_strength,
1242
+ seed,
1243
+ randomize_seed,
1244
+ guidance_scale,
1245
+ num_inference_steps,
1246
+ num_samples,
1247
+ blending,
1248
+ category,
1249
+ target_prompt,
1250
+ resize_default,
1251
+ aspect_ratio_name,
1252
+ invert_mask_state):
1253
+ if original_image is None:
1254
+ if input_image is None:
1255
+ raise gr.Error('Please upload the input image')
1256
+ else:
1257
+ image_pil = input_image["background"].convert("RGB")
1258
+ original_image = np.array(image_pil)
1259
+ if prompt is None or prompt == "":
1260
+ if target_prompt is None or target_prompt == "":
1261
+ raise gr.Error("Please input your instructions, e.g., remove the xxx")
1262
+
1263
+ alpha_mask = input_image["layers"][0].split()[3]
1264
+ input_mask = np.asarray(alpha_mask)
1265
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
1266
+ if output_w == "" or output_h == "":
1267
+ output_h, output_w = original_image.shape[:2]
1268
+
1269
+ if resize_default:
1270
+ short_side = min(output_w, output_h)
1271
+ scale_ratio = 640 / short_side
1272
+ output_w = int(output_w * scale_ratio)
1273
+ output_h = int(output_h * scale_ratio)
1274
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1275
+ original_image = np.array(original_image)
1276
+ if input_mask is not None:
1277
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1278
+ input_mask = np.array(input_mask)
1279
+ if original_mask is not None:
1280
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1281
+ original_mask = np.array(original_mask)
1282
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1283
+ else:
1284
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1285
+ pass
1286
+ else:
1287
+ if resize_default:
1288
+ short_side = min(output_w, output_h)
1289
+ scale_ratio = 640 / short_side
1290
+ output_w = int(output_w * scale_ratio)
1291
+ output_h = int(output_h * scale_ratio)
1292
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1293
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1294
+ original_image = np.array(original_image)
1295
+ if input_mask is not None:
1296
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1297
+ input_mask = np.array(input_mask)
1298
+ if original_mask is not None:
1299
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1300
+ original_mask = np.array(original_mask)
1301
+
1302
+ if invert_mask_state:
1303
+ original_mask = original_mask
1304
+ else:
1305
+ if input_mask.max() == 0:
1306
+ original_mask = original_mask
1307
+ else:
1308
+ original_mask = input_mask
1309
+
1310
+
1311
+ # inpainting directly if target_prompt is not None
1312
+ if category is not None:
1313
+ pass
1314
+ elif target_prompt is not None and len(target_prompt) >= 1 and original_mask is not None:
1315
+ pass
1316
+ else:
1317
+ try:
1318
+ category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device)
1319
+ except Exception as e:
1320
+ raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
1321
+
1322
+
1323
+ if original_mask is not None:
1324
+ original_mask = np.clip(original_mask, 0, 255).astype(np.uint8)
1325
+ else:
1326
+ try:
1327
+ object_wait_for_edit = vlm_response_object_wait_for_edit(
1328
+ vlm_processor,
1329
+ vlm_model,
1330
+ original_image,
1331
+ category,
1332
+ prompt,
1333
+ device)
1334
+
1335
+ original_mask = vlm_response_mask(vlm_processor,
1336
+ vlm_model,
1337
+ category,
1338
+ original_image,
1339
+ prompt,
1340
+ object_wait_for_edit,
1341
+ sam,
1342
+ sam_predictor,
1343
+ sam_automask_generator,
1344
+ groundingdino_model,
1345
+ device).astype(np.uint8)
1346
+ except Exception as e:
1347
+ raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
1348
+
1349
+ if original_mask.ndim == 2:
1350
+ original_mask = original_mask[:,:,None]
1351
+
1352
+
1353
+ if target_prompt is not None and len(target_prompt) >= 1:
1354
+ prompt_after_apply_instruction = target_prompt
1355
+
1356
+ else:
1357
+ try:
1358
+ prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(
1359
+ vlm_processor,
1360
+ vlm_model,
1361
+ original_image,
1362
+ prompt,
1363
+ device)
1364
+ except Exception as e:
1365
+ raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
1366
+
1367
+ generator = torch.Generator(device).manual_seed(random.randint(0, 2147483647) if randomize_seed else seed)
1368
+
1369
+
1370
+ with torch.autocast(device):
1371
+ image, mask_image, mask_np, init_image_np = BrushEdit_Pipeline(pipe,
1372
+ prompt_after_apply_instruction,
1373
+ original_mask,
1374
+ original_image,
1375
+ generator,
1376
+ num_inference_steps,
1377
+ guidance_scale,
1378
+ control_strength,
1379
+ negative_prompt,
1380
+ num_samples,
1381
+ blending)
1382
+ original_image = np.array(init_image_np)
1383
+ masked_image = original_image * (1 - (mask_np>0))
1384
+ masked_image = masked_image.astype(np.uint8)
1385
+ masked_image = Image.fromarray(masked_image)
1386
+ # Save the images (optional)
1387
+ # import uuid
1388
+ # uuid = str(uuid.uuid4())
1389
+ # image[0].save(f"outputs/image_edit_{uuid}_0.png")
1390
+ # image[1].save(f"outputs/image_edit_{uuid}_1.png")
1391
+ # image[2].save(f"outputs/image_edit_{uuid}_2.png")
1392
+ # image[3].save(f"outputs/image_edit_{uuid}_3.png")
1393
+ # mask_image.save(f"outputs/mask_{uuid}.png")
1394
+ # masked_image.save(f"outputs/masked_image_{uuid}.png")
1395
+ # gr.Info(f"Target Prompt: {prompt_after_apply_instruction}", duration=20)
1396
+ return image, [mask_image], [masked_image], prompt, '', False
1397
+
1398
+
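+ # process() overview: resolve the output size, choose the working mask (inverted state,
+ # hand-drawn layer, or a VLM + GroundingDINO + SAM prediction), turn the instruction into
+ # a target prompt via the VLM when none is given, then run BrushEdit_Pipeline to inpaint
+ # num_samples candidate edits.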
1399
+ # Event handler: generate a BLIP caption for the uploaded image
1400
+ def generate_blip_description(input_image):
1401
+ if input_image is None:
1402
+ return "", "Input image cannot be None"
1403
+ try:
1404
+ image_pil = input_image["background"].convert("RGB")
1405
+ except KeyError:
1406
+ return "", "Input image missing 'background' key"
1407
+ except AttributeError as e:
1408
+ return "", f"Invalid image object: {str(e)}"
1409
+ try:
1410
+ description = generate_caption(blip_processor, blip_model, image_pil, device)
1411
+ return description, description # update both the state and the visible textbox
1412
+ except Exception as e:
1413
+ return "", f"Caption generation failed: {str(e)}"
1414
+
1415
+ from app.utils.utils import generate_caption
1416
+ blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
1417
+ clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
1418
+ blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large", torch_dtype=torch.float16).to(device)
1419
+ clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32",torch_dtype=torch.float16).to(device)
1420
+
1421
+
1422
+ def submit_GPT4o_KEY(GPT4o_KEY):
1423
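+ # Note: despite the GPT-4o labelling in the UI, the key is verified with a one-turn chat completion against the DeepSeek-compatible endpoint configured above.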
+ global vlm_model, vlm_processor
1424
+ if vlm_model is not None:
1425
+ del vlm_model
1426
+ torch.cuda.empty_cache()
1427
+ try:
1428
+ vlm_model = OpenAI(api_key=GPT4o_KEY, base_url="https://api.deepseek.com")
1429
+ vlm_processor = ""
1430
+ response = vlm_model.chat.completions.create(
1431
+ model="deepseek-chat",
1432
+ messages=[
1433
+ {"role": "system", "content": "You are a helpful assistant."},
1434
+ {"role": "user", "content": "Hello."}
1435
+ ]
1436
+ )
1437
+ response_str = response.choices[0].message.content
1438
+
1439
+ return "Success. " + response_str, "GPT4-o (Highly Recommended)"
1440
+ except Exception as e:
1441
+ return "Invalid GPT4o API Key", "GPT4-o (Highly Recommended)"
1442
+
1443
+
1444
+ def verify_deepseek_api():
1445
+ try:
1446
+ response = llm_model.chat.completions.create(
1447
+ model="deepseek-chat",
1448
+ messages=[
1449
+ {"role": "system", "content": "You are a helpful assistant."},
1450
+ {"role": "user", "content": "Hello."}
1451
+ ]
1452
+ )
1453
+ response_str = response.choices[0].message.content
1454
+
1455
+ return True, "Success. " + response_str
1456
+
1457
+ except Exception as e:
1458
+ return False, "Invalid DeepSeek API Key"
1459
+
1460
+
1461
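+ # Merge the BLIP caption of the reference image with the user's editing instruction into a single target description via the DeepSeek chat API.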
+ def llm_enhanced_prompt_after_apply_instruction(image_caption, editing_prompt):
1462
+ try:
1463
+ messages = create_apply_editing_messages_deepseek(image_caption, editing_prompt)
1464
+ response = llm_model.chat.completions.create(
1465
+ model="deepseek-chat",
1466
+ messages=messages
1467
+ )
1468
+ response_str = response.choices[0].message.content
1469
+ return response_str
1470
+ except Exception as e:
1471
+ raise gr.Error(f"整合指令时遇到错误: {str(e)},请检查控制台日志获取详细信息")
1472
+
1473
+
1474
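+ # Break the integrated description into a structured query via the DeepSeek chat API (see create_decomposed_query_messages_deepseek).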
+ def llm_decomposed_prompt_after_apply_instruction(integrated_query):
1475
+ try:
1476
+ messages = create_decomposed_query_messages_deepseek(integrated_query)
1477
+ response = llm_model.chat.completions.create(
1478
+ model="deepseek-chat",
1479
+ messages=messages
1480
+ )
1481
+ response_str = response.choices[0].message.content
1482
+ return response_str
1483
+ except Exception as e:
1484
+ raise gr.Error(f"分解指令时遇到错误: {str(e)},请检查控制台日志获取详细信息")
1485
+
1486
+
1487
+ def enhance_description(blip_description, prompt):
1488
+ try:
1489
+ if not prompt or not blip_description:
1490
+ print("Empty prompt or blip_description detected")
1491
+ return "", ""
1492
+
1493
+ print(f"Enhancing with prompt: {prompt}")
1494
+ enhanced_description = llm_enhanced_prompt_after_apply_instruction(blip_description, prompt)
1495
+ return enhanced_description, enhanced_description
1496
+
1497
+ except Exception as e:
1498
+ print(f"Enhancement failed: {str(e)}")
1499
+ return "Error occurred", "Error occurred"
1500
+
1501
+
1502
+ def decompose_description(enhanced_description):
1503
+ try:
1504
+ if not enhanced_description:
1505
+ print("Empty enhanced_description detected")
1506
+ return "", ""
1507
+
1508
+ print(f"Decomposing the enhanced description: {enhanced_description}")
1509
+ decomposed_description = llm_decomposed_prompt_after_apply_instruction(enhanced_description)
1510
+ return decomposed_description, decomposed_description
1511
+
1512
+ except Exception as e:
1513
+ print(f"Decomposition failed: {str(e)}")
1514
+ return "Error occurred", "Error occurred"
1515
+
1516
+
1517
+ @torch.no_grad()
1518
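+ # Compose a retrieval query by fusing CLIP features of the latest edited image (if any) with the enhanced text, then search a prebuilt Faiss index over the Open Images set.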
+ def mix_and_search(enhanced_text: str, gallery_images: list):
1519
+ # Get the most recently generated image tuple from the result gallery
1520
+ latest_item = gallery_images[-1] if gallery_images else None
1521
+
1522
+ # Initialize the list of features to fuse
1523
+ features = []
1524
+
1525
+ # Image feature extraction
1526
+ if latest_item and isinstance(latest_item, tuple):
1527
+ try:
1528
+ image_path = latest_item[0]
1529
+ pil_image = Image.open(image_path).convert("RGB")
1530
+
1531
+ # Preprocess the image with CLIPProcessor
1532
+ image_inputs = clip_processor(
1533
+ images=pil_image,
1534
+ return_tensors="pt"
1535
+ ).to(device)
1536
+
1537
+ image_features = clip_model.get_image_features(**image_inputs)
1538
+ features.append(F.normalize(image_features, dim=-1))
1539
+ except Exception as e:
1540
+ print(f"图像处理失败: {str(e)}")
1541
+
1542
+ # Text feature extraction
1543
+ if enhanced_text.strip():
1544
+ text_inputs = clip_processor(
1545
+ text=enhanced_text,
1546
+ return_tensors="pt",
1547
+ padding=True,
1548
+ truncation=True
1549
+ ).to(device)
1550
+
1551
+ text_features = clip_model.get_text_features(**text_inputs)
1552
+ features.append(F.normalize(text_features, dim=-1))
1553
+
1554
+ if not features:
1555
+ return []
1556
+
1557
+
1558
+ # Feature fusion and retrieval: average the normalized image/text embeddings to form the composed query
1559
+ mixed = sum(features) / len(features)
1560
+ mixed = F.normalize(mixed, dim=-1)
1561
+
1562
+ # Load the Faiss index and the image-path mapping
1563
+ index_path = "/home/zt/data/open-images/train/knn.index"
1564
+ input_data_dir = Path("/home/zt/data/open-images/train/embedding_folder/metadata")
1565
+ base_image_dir = Path("/home/zt/data/open-images/train/")
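+ # Note: these absolute paths point to the original author's local Open Images index; adjust them for your own environment.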
1566
+
1567
+ # Sort the parquet files by the numeric suffix in their filenames and read them directly
1568
+ parquet_files = sorted(
1569
+ input_data_dir.glob('*.parquet'),
1570
+ key=lambda x: int(x.stem.split("_")[-1])
1571
+ )
1572
+
1573
+ # Concatenate all parquet metadata
1574
+ dfs = [pd.read_parquet(file) for file in parquet_files] # read each file inline
1575
+ df = pd.concat(dfs, ignore_index=True)
1576
+ image_paths = df["image_path"].tolist()
1577
+
1578
+ # Read the Faiss index
1579
+ index = faiss.read_index(index_path)
1580
+ assert mixed.shape[1] == index.d, "Feature dimension mismatch"
1581
+
1582
+ # Run the search
1583
+ mixed = mixed.cpu().detach().numpy().astype('float32')
1584
+ distances, indices = index.search(mixed, 50)
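+ # `indices` holds row ids into the concatenated parquet metadata; they are mapped back to image paths below.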
1585
+
1586
+ # Fetch and validate the retrieved image paths
1587
+ retrieved_images = []
1588
+ for idx in indices[0]:
1589
+ if 0 <= idx < len(image_paths):
1590
+ img_path = base_image_dir / image_paths[idx]
1591
+ try:
1592
+ if img_path.exists():
1593
+ retrieved_images.append(Image.open(img_path).convert("RGB"))
1594
+ else:
1595
+ print(f"警告:文件缺失 {img_path}")
1596
+ except Exception as e:
1597
+ print(f"图片加载失败: {str(e)}")
1598
+
1599
+ return retrieved_images if retrieved_images else []
1600
+
1601
+
1602
+
1603
+ block = gr.Blocks(
1604
+ theme=gr.themes.Soft(
1605
+ radius_size=gr.themes.sizes.radius_none,
1606
+ text_size=gr.themes.sizes.text_md
1607
+ )
1608
+ )
1609
+
1610
+ with block as demo:
1611
+ with gr.Row():
1612
+ with gr.Column():
1613
+ gr.HTML(head)
1614
+ gr.Markdown(descriptions)
1615
+ with gr.Accordion(label="🧭 Instructions:", open=True, elem_id="accordion"):
1616
+ with gr.Row(equal_height=True):
1617
+ gr.Markdown(instructions)
1618
+
1619
+ original_image = gr.State(value=None)
1620
+ original_mask = gr.State(value=None)
1621
+ category = gr.State(value=None)
1622
+ status = gr.State(value=None)
1623
+ invert_mask_state = gr.State(value=False)
1624
+ example_change_times = gr.State(value=0)
1625
+ deepseek_verified = gr.State(value=False)
1626
+ blip_description = gr.State(value="")
1627
+ enhanced_description = gr.State(value="")
1628
+ decomposed_description = gr.State(value="")
1629
+
1630
+ with gr.Row():
1631
+ with gr.Column():
1632
+ with gr.Group():
1633
+ input_image = gr.ImageEditor(
1634
+ label="参考图像",
1635
+ type="pil",
1636
+ brush=gr.Brush(colors=["#FFFFFF"], default_size = 30, color_mode="fixed"),
1637
+ layers = False,
1638
+ interactive=True,
1639
+ # height=1024,
1640
+ height=420,
1641
+ sources=["upload"],
1642
+ placeholder="🫧 点击此处或下面的图标上传图像 🫧",
1643
+ )
1644
+ prompt = gr.Textbox(label="修改指令", placeholder="😜 在此处输入你对参考图像的修改预期...", value="",lines=1)
1645
+
1646
+ with gr.Group():
1647
+ mask_button = gr.Button("💎 掩膜生成")
1648
+ with gr.Row():
1649
+ invert_mask_button = gr.Button("👐 掩膜翻转")
1650
+ random_mask_button = gr.Button("⭕️ 随机掩膜")
1651
+ with gr.Row():
1652
+ masked_gallery = gr.Gallery(label="掩膜图像", show_label=True, preview=True, height=360)
1653
+ mask_gallery = gr.Gallery(label="掩膜", show_label=True, preview=True, height=360)
1654
+
1655
+
1656
+
1657
+ with gr.Accordion("高级掩膜选项", open=False, elem_id="accordion1"):
1658
+ dilation_size = gr.Slider(
1659
+ label="每次放缩的尺度: ", show_label=True,minimum=0, maximum=50, step=1, value=20
1660
+ )
1661
+ with gr.Row():
1662
+ dilation_mask_button = gr.Button("放大掩膜")
1663
+ erosion_mask_button = gr.Button("缩小掩膜")
1664
+
1665
+ moving_pixels = gr.Slider(
1666
+ label="每次移动的像素:", show_label=True, minimum=0, maximum=50, value=4, step=1
1667
+ )
1668
+ with gr.Row():
1669
+ move_left_button = gr.Button("左移")
1670
+ move_right_button = gr.Button("右移")
1671
+ with gr.Row():
1672
+ move_up_button = gr.Button("上移")
1673
+ move_down_button = gr.Button("下移")
1674
+
1675
+
1676
+
1677
+ with gr.Column():
1678
+ with gr.Row():
1679
+ deepseek_key = gr.Textbox(label="LLM API密钥", value="sk-d145b963a92649a88843caeb741e8bbc", lines=2, container=False)
1680
+ verify_deepseek = gr.Button("🔑 验证密钥", scale=0)
1681
+ blip_output = gr.Textbox(label="1. 原图描述(BLIP生成)", placeholder="🖼️ 上传图片后自动生成图片描述...", lines=2, interactive=True)
1682
+ with gr.Row():
1683
+ enhanced_output = gr.Textbox(label="2. 整合增强版", lines=4, interactive=True, placeholder="🚀 点击右侧按钮生成增强描述...")
1684
+ enhance_button = gr.Button("✨ 智能整合")
1685
+
1686
+ with gr.Row():
1687
+ decomposed_output = gr.Textbox(label="3. 结构分解版", lines=4, interactive=True, placeholder="📝 点击右侧按钮生成结构化描述...")
1688
+ decompose_button = gr.Button("🔧 结构分解")
1689
+
1690
+
1691
+
1692
+ with gr.Group():
1693
+ run_button = gr.Button("💫 图像编辑")
1694
+ result_gallery = gr.Gallery(label="💥 编辑结果", show_label=True, columns=2, preview=True, height=360)
1695
+
1696
+ with gr.Accordion("高级编辑选项", open=False, elem_id="accordion1"):
1697
+ vlm_model_dropdown = gr.Dropdown(label="VLM 模型", choices=VLM_MODEL_NAMES, value=DEFAULT_VLM_MODEL_NAME, interactive=True)
1698
+
1699
+ with gr.Group():
1700
+ with gr.Row():
1701
+ # GPT4o_KEY = gr.Textbox(label="GPT4o API Key", placeholder="Please input your GPT4o API Key when use GPT4o VLM (highly recommended).", value="", lines=1)
1702
+ GPT4o_KEY = gr.Textbox(label="VLM API密钥", value="sk-d145b963a92649a88843caeb741e8bbc", lines=1)
1703
+ GPT4o_KEY_submit = gr.Button("🔑 验证密钥")
1704
+
1705
+ aspect_ratio = gr.Dropdown(label="输出纵横比", choices=ASPECT_RATIO_LABELS, value=DEFAULT_ASPECT_RATIO)
1706
+ resize_default = gr.Checkbox(label="短边裁剪到640像素", value=True)
1707
+ base_model_dropdown = gr.Dropdown(label="基础模型", choices=BASE_MODELS, value=DEFAULT_BASE_MODEL, interactive=True)
1708
+ negative_prompt = gr.Text(label="负向提示", max_lines=5, placeholder="请输入你的负向提示", value='ugly, low quality',lines=1)
1709
+ control_strength = gr.Slider(label="控制强度: ", show_label=True, minimum=0, maximum=1.1, value=1, step=0.01)
1710
+ with gr.Group():
1711
+ seed = gr.Slider(label="种子: ", minimum=0, maximum=2147483647, step=1, value=648464818)
1712
+ randomize_seed = gr.Checkbox(label="随机种子", value=False)
1713
+ blending = gr.Checkbox(label="混合模式", value=True)
1714
+ num_samples = gr.Slider(label="生成个数", minimum=0, maximum=4, step=1, value=2)
1715
+ with gr.Group():
1716
+ with gr.Row():
1717
+ guidance_scale = gr.Slider(label="指导尺度", minimum=1, maximum=12, step=0.1, value=7.5)
1718
+ num_inference_steps = gr.Slider(label="推理步数", minimum=1, maximum=50, step=1, value=50)
1719
+ target_prompt = gr.Text(label="Input Target Prompt", max_lines=5, placeholder="VLM-generated target prompt, you can first generate if and then modify it (optional)", value='', lines=2)
1720
+
1721
+
1722
+
1723
+ init_type = gr.Textbox(label="Init Name", value="", visible=False)
1724
+ example_type = gr.Textbox(label="Example Name", value="", visible=False)
1725
+
1726
+ with gr.Row():
1727
+ reset_button = gr.Button("Reset")
1728
+ retrieve_button = gr.Button("🔍 开始检索")
1729
+
1730
+ with gr.Row():
1731
+ retrieve_gallery = gr.Gallery(label="🎊 检索结果", show_label=True, columns=10, preview=True, height=800)
1732
+
1733
+
1734
+ with gr.Row():
1735
+ example = gr.Examples(
1736
+ label="Quick Example",
1737
+ examples=EXAMPLES,
1738
+ inputs=[input_image, prompt, seed, init_type, example_type, blending, resize_default, vlm_model_dropdown],
1739
+ examples_per_page=10,
1740
+ cache_examples=False,
1741
+ )
1742
+
1743
+
1744
+ with gr.Accordion(label="🎬 Feature Details:", open=True, elem_id="accordion"):
1745
+ with gr.Row(equal_height=True):
1746
+ gr.Markdown(tips)
1747
+
1748
+ with gr.Row():
1749
+ gr.Markdown(citation)
1750
+
1751
+ ## gr.examples can not be used to update the gr.Gallery, so we need to use the following two functions to update the gr.Gallery.
1752
+ ## And we need to solve the conflict between the upload and change example functions.
1753
+ input_image.upload(
1754
+ init_img,
1755
+ [input_image, init_type, prompt, aspect_ratio, example_change_times],
1756
+ [input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, init_type, aspect_ratio, resize_default, invert_mask_state, example_change_times]
1757
+
1758
+ )
1759
+ example_type.change(fn=update_example, inputs=[example_type, prompt, example_change_times], outputs=[input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, target_prompt, invert_mask_state, example_change_times])
1760
+
1761
+
1762
+ ## vlm and base model dropdown
1763
+ vlm_model_dropdown.change(fn=update_vlm_model, inputs=[vlm_model_dropdown], outputs=[status])
1764
+ base_model_dropdown.change(fn=update_base_model, inputs=[base_model_dropdown], outputs=[status])
1765
+
1766
+ GPT4o_KEY_submit.click(fn=submit_GPT4o_KEY, inputs=[GPT4o_KEY], outputs=[GPT4o_KEY, vlm_model_dropdown])
1767
+
1768
+
1769
+ ips=[input_image,
1770
+ original_image,
1771
+ original_mask,
1772
+ prompt,
1773
+ negative_prompt,
1774
+ control_strength,
1775
+ seed,
1776
+ randomize_seed,
1777
+ guidance_scale,
1778
+ num_inference_steps,
1779
+ num_samples,
1780
+ blending,
1781
+ category,
1782
+ target_prompt,
1783
+ resize_default,
1784
+ aspect_ratio,
1785
+ invert_mask_state]
1786
+
1787
+ ## run brushedit
1788
+ run_button.click(fn=process, inputs=ips, outputs=[result_gallery, mask_gallery, masked_gallery, prompt, target_prompt, invert_mask_state])
1789
+
1790
+
1791
+ ## mask func
1792
+ mask_button.click(fn=process_mask, inputs=[input_image, original_image, prompt, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask, category])
1793
+ random_mask_button.click(fn=process_random_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
1794
+ dilation_mask_button.click(fn=process_dilation_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[ masked_gallery, mask_gallery, original_mask])
1795
+ erosion_mask_button.click(fn=process_erosion_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[ masked_gallery, mask_gallery, original_mask])
1796
+ invert_mask_button.click(fn=invert_mask, inputs=[input_image, original_image, original_mask], outputs=[masked_gallery, mask_gallery, original_mask, invert_mask_state])
1797
+
1798
+ ## reset func
1799
+ reset_button.click(fn=reset_func, inputs=[input_image, original_image, original_mask, prompt, target_prompt], outputs=[input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, resize_default, invert_mask_state])
1800
+
1801
+ input_image.upload(fn=generate_blip_description, inputs=[input_image], outputs=[blip_description, blip_output])
1802
+ verify_deepseek.click(fn=verify_deepseek_api, outputs=[deepseek_verified, deepseek_key])
1803
+ enhance_button.click(fn=enhance_description, inputs=[blip_output, prompt], outputs=[enhanced_description, enhanced_output])
1804
+ decompose_button.click(fn=decompose_description, inputs=[enhanced_output], outputs=[decomposed_description, decomposed_output])
1805
+ retrieve_button.click(fn=mix_and_search, inputs=[enhanced_output, result_gallery], outputs=[retrieve_gallery])
1806
+
1807
+ demo.launch(server_name="0.0.0.0", server_port=12345, share=True)
1808
+
1809
+
brushedit_app_new_doable.py ADDED
@@ -0,0 +1,1860 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import os, random, sys
4
+ import numpy as np
5
+ import requests
6
+ import torch
7
+
8
+ from pathlib import Path
9
+ import pandas as pd
10
+ import concurrent.futures
11
+ import faiss
12
+ import gradio as gr
13
+
14
+ from PIL import Image
15
+
16
+ import torch.nn.functional as F # newly added (used for normalizing CLIP features)
17
+ from huggingface_hub import hf_hub_download, snapshot_download
18
+ from scipy.ndimage import binary_dilation, binary_erosion
19
+ from transformers import (LlavaNextProcessor, LlavaNextForConditionalGeneration,
20
+ Qwen2VLForConditionalGeneration, Qwen2VLProcessor)
21
+
22
+ from segment_anything import SamPredictor, build_sam, SamAutomaticMaskGenerator
23
+ from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler
24
+ from diffusers.image_processor import VaeImageProcessor
25
+
26
+
27
+ from app.src.vlm_pipeline import (
28
+ vlm_response_editing_type,
29
+ vlm_response_object_wait_for_edit,
30
+ vlm_response_mask,
31
+ vlm_response_prompt_after_apply_instruction
32
+ )
33
+ from app.src.brushedit_all_in_one_pipeline import BrushEdit_Pipeline
34
+ from app.utils.utils import load_grounding_dino_model
35
+
36
+ from app.src.vlm_template import vlms_template
37
+ from app.src.base_model_template import base_models_template
38
+ from app.src.aspect_ratio_template import aspect_ratios
39
+
40
+ from openai import OpenAI
41
+ base_openai_url = "https://api.deepseek.com/"
42
+ base_api_key = "sk-d145b963a92649a88843caeb741e8bbc"
43
+
44
+
45
+ from transformers import BlipProcessor, BlipForConditionalGeneration
46
+ from transformers import CLIPProcessor, CLIPModel
47
+
48
+ from app.deepseek.instructions import (
49
+ create_apply_editing_messages_deepseek,
50
+ create_decomposed_query_messages_deepseek
51
+ )
52
+ from clip_retrieval.clip_client import ClipClient
53
+
54
+ #### Description ####
55
+ logo = r"""
56
+ <center><img src='./assets/logo_brushedit.png' alt='BrushEdit logo' style="width:80px; margin-bottom:10px"></center>
57
+ """
58
+ head = r"""
59
+ <div style="text-align: center;">
60
+ <h1> 基于扩散模型先验和大语言模型的零样本组合查询图像检索</h1>
61
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
62
+ <a href=''><img src='https://img.shields.io/badge/Project_Page-BrushEdit-green' alt='Project Page'></a>
63
+ <a href=''><img src='https://img.shields.io/badge/Paper-Arxiv-blue'></a>
64
+ <a href=''><img src='https://img.shields.io/badge/Code-Github-orange'></a>
65
+
66
+ </div>
67
+ </br>
68
+ </div>
69
+ """
70
+ descriptions = r"""
71
+ Demo for ZS-CIR"""
72
+
73
+ instructions = r"""
74
+ Demo for ZS-CIR"""
75
+
76
+ tips = r"""
77
+ Demo for ZS-CIR
78
+
79
+ """
80
+
81
+
82
+
83
+ citation = r"""
84
+ Demo for ZS-CIR"""
85
+
86
+ # - - - - - examples - - - - - #
87
+ EXAMPLES = [
88
+
89
+ [
90
+ Image.open("./assets/frog/frog.jpeg").convert("RGBA"),
91
+ "add a magic hat on frog head.",
92
+ 642087011,
93
+ "frog",
94
+ "frog",
95
+ True,
96
+ False,
97
+ "GPT4-o (Highly Recommended)"
98
+ ],
99
+ [
100
+ Image.open("./assets/chinese_girl/chinese_girl.png").convert("RGBA"),
101
+ "replace the background to ancient China.",
102
+ 648464818,
103
+ "chinese_girl",
104
+ "chinese_girl",
105
+ True,
106
+ False,
107
+ "GPT4-o (Highly Recommended)"
108
+ ],
109
+ [
110
+ Image.open("./assets/angel_christmas/angel_christmas.png").convert("RGBA"),
111
+ "remove the deer.",
112
+ 648464818,
113
+ "angel_christmas",
114
+ "angel_christmas",
115
+ False,
116
+ False,
117
+ "GPT4-o (Highly Recommended)"
118
+ ],
119
+ [
120
+ Image.open("./assets/sunflower_girl/sunflower_girl.png").convert("RGBA"),
121
+ "add a wreath on head.",
122
+ 648464818,
123
+ "sunflower_girl",
124
+ "sunflower_girl",
125
+ True,
126
+ False,
127
+ "GPT4-o (Highly Recommended)"
128
+ ],
129
+ [
130
+ Image.open("./assets/girl_on_sun/girl_on_sun.png").convert("RGBA"),
131
+ "add a butterfly fairy.",
132
+ 648464818,
133
+ "girl_on_sun",
134
+ "girl_on_sun",
135
+ True,
136
+ False,
137
+ "GPT4-o (Highly Recommended)"
138
+ ],
139
+ [
140
+ Image.open("./assets/spider_man_rm/spider_man.png").convert("RGBA"),
141
+ "remove the christmas hat.",
142
+ 642087011,
143
+ "spider_man_rm",
144
+ "spider_man_rm",
145
+ False,
146
+ False,
147
+ "GPT4-o (Highly Recommended)"
148
+ ],
149
+ [
150
+ Image.open("./assets/anime_flower/anime_flower.png").convert("RGBA"),
151
+ "remove the flower.",
152
+ 642087011,
153
+ "anime_flower",
154
+ "anime_flower",
155
+ False,
156
+ False,
157
+ "GPT4-o (Highly Recommended)"
158
+ ],
159
+ [
160
+ Image.open("./assets/chenduling/chengduling.jpg").convert("RGBA"),
161
+ "replace the clothes to a delicated floral skirt.",
162
+ 648464818,
163
+ "chenduling",
164
+ "chenduling",
165
+ True,
166
+ False,
167
+ "GPT4-o (Highly Recommended)"
168
+ ],
169
+ [
170
+ Image.open("./assets/hedgehog_rp_bg/hedgehog.png").convert("RGBA"),
171
+ "make the hedgehog in Italy.",
172
+ 648464818,
173
+ "hedgehog_rp_bg",
174
+ "hedgehog_rp_bg",
175
+ True,
176
+ False,
177
+ "GPT4-o (Highly Recommended)"
178
+ ],
179
+
180
+ ]
181
+
182
+ INPUT_IMAGE_PATH = {
183
+ "frog": "./assets/frog/frog.jpeg",
184
+ "chinese_girl": "./assets/chinese_girl/chinese_girl.png",
185
+ "angel_christmas": "./assets/angel_christmas/angel_christmas.png",
186
+ "sunflower_girl": "./assets/sunflower_girl/sunflower_girl.png",
187
+ "girl_on_sun": "./assets/girl_on_sun/girl_on_sun.png",
188
+ "spider_man_rm": "./assets/spider_man_rm/spider_man.png",
189
+ "anime_flower": "./assets/anime_flower/anime_flower.png",
190
+ "chenduling": "./assets/chenduling/chengduling.jpg",
191
+ "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/hedgehog.png",
192
+ }
193
+ MASK_IMAGE_PATH = {
194
+ "frog": "./assets/frog/mask_f7b350de-6f2c-49e3-b535-995c486d78e7.png",
195
+ "chinese_girl": "./assets/chinese_girl/mask_54759648-0989-48e0-bc82-f20e28b5ec29.png",
196
+ "angel_christmas": "./assets/angel_christmas/mask_f15d9b45-c978-4e3d-9f5f-251e308560c3.png",
197
+ "sunflower_girl": "./assets/sunflower_girl/mask_99cc50b4-7dc4-4de5-8748-ec10772f0317.png",
198
+ "girl_on_sun": "./assets/girl_on_sun/mask_264eac8b-8b65-479c-9755-020a60880c37.png",
199
+ "spider_man_rm": "./assets/spider_man_rm/mask_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png",
200
+ "anime_flower": "./assets/anime_flower/mask_37553172-9b38-4727-bf2e-37d7e2b93461.png",
201
+ "chenduling": "./assets/chenduling/mask_68e3ff6f-da07-4b37-91df-13d6eed7b997.png",
202
+ "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/mask_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png",
203
+ }
204
+ MASKED_IMAGE_PATH = {
205
+ "frog": "./assets/frog/masked_image_f7b350de-6f2c-49e3-b535-995c486d78e7.png",
206
+ "chinese_girl": "./assets/chinese_girl/masked_image_54759648-0989-48e0-bc82-f20e28b5ec29.png",
207
+ "angel_christmas": "./assets/angel_christmas/masked_image_f15d9b45-c978-4e3d-9f5f-251e308560c3.png",
208
+ "sunflower_girl": "./assets/sunflower_girl/masked_image_99cc50b4-7dc4-4de5-8748-ec10772f0317.png",
209
+ "girl_on_sun": "./assets/girl_on_sun/masked_image_264eac8b-8b65-479c-9755-020a60880c37.png",
210
+ "spider_man_rm": "./assets/spider_man_rm/masked_image_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png",
211
+ "anime_flower": "./assets/anime_flower/masked_image_37553172-9b38-4727-bf2e-37d7e2b93461.png",
212
+ "chenduling": "./assets/chenduling/masked_image_68e3ff6f-da07-4b37-91df-13d6eed7b997.png",
213
+ "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/masked_image_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png",
214
+ }
215
+ OUTPUT_IMAGE_PATH = {
216
+ "frog": "./assets/frog/image_edit_f7b350de-6f2c-49e3-b535-995c486d78e7_1.png",
217
+ "chinese_girl": "./assets/chinese_girl/image_edit_54759648-0989-48e0-bc82-f20e28b5ec29_1.png",
218
+ "angel_christmas": "./assets/angel_christmas/image_edit_f15d9b45-c978-4e3d-9f5f-251e308560c3_0.png",
219
+ "sunflower_girl": "./assets/sunflower_girl/image_edit_99cc50b4-7dc4-4de5-8748-ec10772f0317_3.png",
220
+ "girl_on_sun": "./assets/girl_on_sun/image_edit_264eac8b-8b65-479c-9755-020a60880c37_0.png",
221
+ "spider_man_rm": "./assets/spider_man_rm/image_edit_a5d410e6-8e8d-432f-8144-defbc3e1eae9_0.png",
222
+ "anime_flower": "./assets/anime_flower/image_edit_37553172-9b38-4727-bf2e-37d7e2b93461_2.png",
223
+ "chenduling": "./assets/chenduling/image_edit_68e3ff6f-da07-4b37-91df-13d6eed7b997_0.png",
224
+ "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/image_edit_db7f8bf8-8349-46d3-b14e-43d67fbe25d3_3.png",
225
+ }
226
+
227
+ # os.environ['GRADIO_TEMP_DIR'] = 'gradio_temp_dir'
228
+ # os.makedirs('gradio_temp_dir', exist_ok=True)
229
+
230
+ VLM_MODEL_NAMES = list(vlms_template.keys())
231
+ DEFAULT_VLM_MODEL_NAME = "Qwen2-VL-7B-Instruct (Default)"
232
+
233
+
234
+ BASE_MODELS = list(base_models_template.keys())
235
+ DEFAULT_BASE_MODEL = "realisticVision (Default)"
236
+
237
+ ASPECT_RATIO_LABELS = list(aspect_ratios)
238
+ DEFAULT_ASPECT_RATIO = ASPECT_RATIO_LABELS[0]
239
+
240
+
241
+ ## init device
242
+ try:
243
+ if torch.cuda.is_available():
244
+ device = "cuda"
245
+ elif sys.platform == "darwin" and torch.backends.mps.is_available():
246
+ device = "mps"
247
+ else:
248
+ device = "cpu"
249
+ except:
250
+ device = "cpu"
251
+
252
+ # ## init torch dtype
253
+ # if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
254
+ # torch_dtype = torch.bfloat16
255
+ # else:
256
+ # torch_dtype = torch.float16
257
+
258
+ # if device == "mps":
259
+ # torch_dtype = torch.float16
260
+
261
+ torch_dtype = torch.float16
262
+
263
+
264
+
265
+ # download hf models
266
+ BrushEdit_path = "models/"
267
+ if not os.path.exists(BrushEdit_path):
268
+ BrushEdit_path = snapshot_download(
269
+ repo_id="TencentARC/BrushEdit",
270
+ local_dir=BrushEdit_path,
271
+ token=os.getenv("HF_TOKEN"),
272
+ )
273
+
274
+ ## init default VLM
275
+ vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[DEFAULT_VLM_MODEL_NAME]
276
+ if vlm_processor != "" and vlm_model != "":
277
+ vlm_model.to(device)
278
+ else:
279
+ raise gr.Error("Please Download default VLM model "+ DEFAULT_VLM_MODEL_NAME +" first.")
280
+
281
+ ## init default LLM
282
+ llm_model = OpenAI(api_key=base_api_key, base_url=base_openai_url)
283
+
284
+ ## init base model
285
+ base_model_path = os.path.join(BrushEdit_path, "base_model/realisticVisionV60B1_v51VAE")
286
+ brushnet_path = os.path.join(BrushEdit_path, "brushnetX")
287
+ sam_path = os.path.join(BrushEdit_path, "sam/sam_vit_h_4b8939.pth")
288
+ groundingdino_path = os.path.join(BrushEdit_path, "grounding_dino/groundingdino_swint_ogc.pth")
289
+
290
+
291
+ # input brushnetX ckpt path
292
+ brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch_dtype)
293
+ pipe = StableDiffusionBrushNetPipeline.from_pretrained(
294
+ base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
295
+ )
296
+ # speed up diffusion process with faster scheduler and memory optimization
297
+ pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
298
+ # remove following line if xformers is not installed or when using Torch 2.0.
299
+ # pipe.enable_xformers_memory_efficient_attention()
300
+ pipe.enable_model_cpu_offload()
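+ # Model CPU offload keeps only the active sub-module on the GPU, trading some speed for lower peak VRAM.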
301
+
302
+
303
+ ## init SAM
304
+ sam = build_sam(checkpoint=sam_path)
305
+ sam.to(device=device)
306
+ sam_predictor = SamPredictor(sam)
307
+ sam_automask_generator = SamAutomaticMaskGenerator(sam)
308
+
309
+ ## init groundingdino_model
310
+ config_file = 'app/utils/GroundingDINO_SwinT_OGC.py'
311
+ groundingdino_model = load_grounding_dino_model(config_file, groundingdino_path, device=device)
312
+
313
+ ## Ordinary function
314
+ def crop_and_resize(image: Image.Image,
315
+ target_width: int,
316
+ target_height: int) -> Image.Image:
317
+ """
318
+ Crops and resizes an image while preserving the aspect ratio.
319
+
320
+ Args:
321
+ image (Image.Image): Input PIL image to be cropped and resized.
322
+ target_width (int): Target width of the output image.
323
+ target_height (int): Target height of the output image.
324
+
325
+ Returns:
326
+ Image.Image: Cropped and resized image.
327
+ """
328
+ # Original dimensions
329
+ original_width, original_height = image.size
330
+ original_aspect = original_width / original_height
331
+ target_aspect = target_width / target_height
332
+
333
+ # Calculate crop box to maintain aspect ratio
334
+ if original_aspect > target_aspect:
335
+ # Crop horizontally
336
+ new_width = int(original_height * target_aspect)
337
+ new_height = original_height
338
+ left = (original_width - new_width) / 2
339
+ top = 0
340
+ right = left + new_width
341
+ bottom = original_height
342
+ else:
343
+ # Crop vertically
344
+ new_width = original_width
345
+ new_height = int(original_width / target_aspect)
346
+ left = 0
347
+ top = (original_height - new_height) / 2
348
+ right = original_width
349
+ bottom = top + new_height
350
+
351
+ # Crop and resize
352
+ cropped_image = image.crop((left, top, right, bottom))
353
+ resized_image = cropped_image.resize((target_width, target_height), Image.NEAREST)
354
+ return resized_image
355
+
356
+
357
+ ## Ordinary function
358
+ def resize(image: Image.Image,
359
+ target_width: int,
360
+ target_height: int) -> Image.Image:
361
+ """
362
+ Resizes an image to the target size; the aspect ratio is not preserved.
363
+
364
+ Args:
365
+ image (Image.Image): Input PIL image to be resized.
366
+ target_width (int): Target width of the output image.
367
+ target_height (int): Target height of the output image.
368
+
369
+ Returns:
370
+ Image.Image: Resized image.
371
+ """
372
+ # Direct nearest-neighbour resize to the target size
373
+ resized_image = image.resize((target_width, target_height), Image.NEAREST)
374
+ return resized_image
375
+
376
+
377
+ def move_mask_func(mask, direction, units):
378
+ binary_mask = mask.squeeze()>0
379
+ rows, cols = binary_mask.shape
380
+ moved_mask = np.zeros_like(binary_mask, dtype=bool)
381
+
382
+ if direction == 'down':
383
+ # move down
384
+ moved_mask[max(0, units):, :] = binary_mask[:rows - units, :]
385
+
386
+ elif direction == 'up':
387
+ # move up
388
+ moved_mask[:rows - units, :] = binary_mask[units:, :]
389
+
390
+ elif direction == 'right':
391
+ # move right
392
+ moved_mask[:, max(0, units):] = binary_mask[:, :cols - units]
393
+
394
+ elif direction == 'left':
395
+ # move left
396
+ moved_mask[:, :cols - units] = binary_mask[:, units:]
397
+
398
+ return moved_mask
399
+
400
+
401
+ def random_mask_func(mask, dilation_type='square', dilation_size=20):
402
+ # Binarize the mask, then grow, shrink, or box/ellipse-fit it according to dilation_type
403
+ binary_mask = mask.squeeze()>0
404
+
405
+ if dilation_type == 'square_dilation':
406
+ structure = np.ones((dilation_size, dilation_size), dtype=bool)
407
+ dilated_mask = binary_dilation(binary_mask, structure=structure)
408
+ elif dilation_type == 'square_erosion':
409
+ structure = np.ones((dilation_size, dilation_size), dtype=bool)
410
+ dilated_mask = binary_erosion(binary_mask, structure=structure)
411
+ elif dilation_type == 'bounding_box':
412
+ # find the most left top and left bottom point
413
+ rows, cols = np.where(binary_mask)
414
+ if len(rows) == 0 or len(cols) == 0:
415
+ return mask # return original mask if no valid points
416
+
417
+ min_row = np.min(rows)
418
+ max_row = np.max(rows)
419
+ min_col = np.min(cols)
420
+ max_col = np.max(cols)
421
+
422
+ # create a bounding box
423
+ dilated_mask = np.zeros_like(binary_mask, dtype=bool)
424
+ dilated_mask[min_row:max_row + 1, min_col:max_col + 1] = True
425
+
426
+ elif dilation_type == 'bounding_ellipse':
427
+ # find the most left top and left bottom point
428
+ rows, cols = np.where(binary_mask)
429
+ if len(rows) == 0 or len(cols) == 0:
430
+ return mask # return original mask if no valid points
431
+
432
+ min_row = np.min(rows)
433
+ max_row = np.max(rows)
434
+ min_col = np.min(cols)
435
+ max_col = np.max(cols)
436
+
437
+ # calculate the center and axis length of the ellipse
438
+ center = ((min_col + max_col) // 2, (min_row + max_row) // 2)
439
+ a = (max_col - min_col) // 2 # half long axis
440
+ b = (max_row - min_row) // 2 # half short axis
441
+
442
+ # create a bounding ellipse
443
+ y, x = np.ogrid[:mask.shape[0], :mask.shape[1]]
444
+ ellipse_mask = ((x - center[0])**2 / a**2 + (y - center[1])**2 / b**2) <= 1
445
+ dilated_mask = np.zeros_like(binary_mask, dtype=bool)
446
+ dilated_mask[ellipse_mask] = True
447
+ else:
448
+ ValueError("dilation_type must be 'square' or 'ellipse'")
449
+
450
+ # use binary dilation
451
+ dilated_mask = np.uint8(dilated_mask[:,:,np.newaxis]) * 255
452
+ return dilated_mask
453
+
454
+
455
+ ## Gradio component function
456
+ def update_vlm_model(vlm_name):
457
+ global vlm_model, vlm_processor
458
+ if vlm_model is not None:
459
+ del vlm_model
460
+ torch.cuda.empty_cache()
461
+
462
+ vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[vlm_name]
463
+
464
+ ## We recommend using preloaded models; otherwise the first switch will take a long time to download the weights. You can edit vlm_template.py to change this.
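+ # Each vlms_template entry is unpacked above as (vlm_type, local_path, processor, model); entries whose processor/model are empty strings are loaded on demand in the branches below.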
465
+ if vlm_type == "llava-next":
466
+ if vlm_processor != "" and vlm_model != "":
467
+ vlm_model.to(device)
468
+ return vlm_model_dropdown
469
+ else:
470
+ if os.path.exists(vlm_local_path):
471
+ vlm_processor = LlavaNextProcessor.from_pretrained(vlm_local_path)
472
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto")
473
+ else:
474
+ if vlm_name == "llava-v1.6-mistral-7b-hf (Preload)":
475
+ vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
476
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", torch_dtype="auto", device_map="auto")
477
+ elif vlm_name == "llama3-llava-next-8b-hf (Preload)":
478
+ vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llama3-llava-next-8b-hf")
479
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llama3-llava-next-8b-hf", torch_dtype="auto", device_map="auto")
480
+ elif vlm_name == "llava-v1.6-vicuna-13b-hf (Preload)":
481
+ vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf")
482
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf", torch_dtype="auto", device_map="auto")
483
+ elif vlm_name == "llava-v1.6-34b-hf (Preload)":
484
+ vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-34b-hf")
485
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-34b-hf", torch_dtype="auto", device_map="auto")
486
+ elif vlm_name == "llava-next-72b-hf (Preload)":
487
+ vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-next-72b-hf")
488
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-next-72b-hf", torch_dtype="auto", device_map="auto")
489
+ elif vlm_type == "qwen2-vl":
490
+ if vlm_processor != "" and vlm_model != "":
491
+ vlm_model.to(device)
492
+ return vlm_model_dropdown
493
+ else:
494
+ if os.path.exists(vlm_local_path):
495
+ vlm_processor = Qwen2VLProcessor.from_pretrained(vlm_local_path)
496
+ vlm_model = Qwen2VLForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto")
497
+ else:
498
+ if vlm_name == "qwen2-vl-2b-instruct (Preload)":
499
+ vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
500
+ vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", torch_dtype="auto", device_map="auto")
501
+ elif vlm_name == "qwen2-vl-7b-instruct (Preload)":
502
+ vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
503
+ vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto")
504
+ elif vlm_name == "qwen2-vl-72b-instruct (Preload)":
505
+ vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-72B-Instruct")
506
+ vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-72B-Instruct", torch_dtype="auto", device_map="auto")
507
+ elif vlm_type == "openai":
508
+ pass
509
+ return "success"
510
+
511
+
512
+ def update_base_model(base_model_name):
513
+ global pipe
514
+ ## We recommend using preloaded models; otherwise the first switch will take a long time to download the weights. You can edit base_model_template.py to change this.
515
+ if pipe is not None:
516
+ del pipe
517
+ torch.cuda.empty_cache()
518
+ base_model_path, pipe = base_models_template[base_model_name]
519
+ if pipe != "":
520
+ pipe.to(device)
521
+ else:
522
+ if os.path.exists(base_model_path):
523
+ pipe = StableDiffusionBrushNetPipeline.from_pretrained(
524
+ base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
525
+ )
526
+ # pipe.enable_xformers_memory_efficient_attention()
527
+ pipe.enable_model_cpu_offload()
528
+ else:
529
+ raise gr.Error(f"The base model {base_model_name} does not exist")
530
+ return "success"
531
+
532
+
533
+ def process_random_mask(input_image,
534
+ original_image,
535
+ original_mask,
536
+ resize_default,
537
+ aspect_ratio_name,
538
+ ):
539
+
540
+ alpha_mask = input_image["layers"][0].split()[3]
541
+ input_mask = np.asarray(alpha_mask)
542
+
543
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
544
+ if output_w == "" or output_h == "":
545
+ output_h, output_w = original_image.shape[:2]
546
+ if resize_default:
547
+ short_side = min(output_w, output_h)
548
+ scale_ratio = 640 / short_side
549
+ output_w = int(output_w * scale_ratio)
550
+ output_h = int(output_h * scale_ratio)
551
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
552
+ original_image = np.array(original_image)
553
+ if input_mask is not None:
554
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
555
+ input_mask = np.array(input_mask)
556
+ if original_mask is not None:
557
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
558
+ original_mask = np.array(original_mask)
559
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
560
+ else:
561
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
562
+ pass
563
+ else:
564
+ if resize_default:
565
+ short_side = min(output_w, output_h)
566
+ scale_ratio = 640 / short_side
567
+ output_w = int(output_w * scale_ratio)
568
+ output_h = int(output_h * scale_ratio)
569
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
570
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
571
+ original_image = np.array(original_image)
572
+ if input_mask is not None:
573
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
574
+ input_mask = np.array(input_mask)
575
+ if original_mask is not None:
576
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
577
+ original_mask = np.array(original_mask)
578
+
579
+
580
+ if input_mask.max() == 0:
581
+ original_mask = original_mask
582
+ else:
583
+ original_mask = input_mask
584
+
585
+ if original_mask is None:
586
+ raise gr.Error('Please generate mask first')
587
+
588
+ if original_mask.ndim == 2:
589
+ original_mask = original_mask[:,:,None]
590
+
591
+ dilation_type = np.random.choice(['bounding_box', 'bounding_ellipse'])
592
+ random_mask = random_mask_func(original_mask, dilation_type).squeeze()
593
+
594
+ mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
595
+
596
+ masked_image = original_image * (1 - (random_mask[:,:,None]>0))
597
+ masked_image = masked_image.astype(original_image.dtype)
598
+ masked_image = Image.fromarray(masked_image)
599
+
600
+
601
+ return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
602
+
603
+
604
+ def process_dilation_mask(input_image,
605
+ original_image,
606
+ original_mask,
607
+ resize_default,
608
+ aspect_ratio_name,
609
+ dilation_size=20):
610
+
611
+ alpha_mask = input_image["layers"][0].split()[3]
612
+ input_mask = np.asarray(alpha_mask)
613
+
614
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
615
+ if output_w == "" or output_h == "":
616
+ output_h, output_w = original_image.shape[:2]
617
+ if resize_default:
618
+ short_side = min(output_w, output_h)
619
+ scale_ratio = 640 / short_side
620
+ output_w = int(output_w * scale_ratio)
621
+ output_h = int(output_h * scale_ratio)
622
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
623
+ original_image = np.array(original_image)
624
+ if input_mask is not None:
625
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
626
+ input_mask = np.array(input_mask)
627
+ if original_mask is not None:
628
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
629
+ original_mask = np.array(original_mask)
630
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
631
+ else:
632
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
633
+ pass
634
+ else:
635
+ if resize_default:
636
+ short_side = min(output_w, output_h)
637
+ scale_ratio = 640 / short_side
638
+ output_w = int(output_w * scale_ratio)
639
+ output_h = int(output_h * scale_ratio)
640
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
641
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
642
+ original_image = np.array(original_image)
643
+ if input_mask is not None:
644
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
645
+ input_mask = np.array(input_mask)
646
+ if original_mask is not None:
647
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
648
+ original_mask = np.array(original_mask)
649
+
650
+ if input_mask.max() == 0:
651
+ original_mask = original_mask
652
+ else:
653
+ original_mask = input_mask
654
+
655
+ if original_mask is None:
656
+ raise gr.Error('Please generate mask first')
657
+
658
+ if original_mask.ndim == 2:
659
+ original_mask = original_mask[:,:,None]
660
+
661
+ dilation_type = np.random.choice(['square_dilation'])
662
+ random_mask = random_mask_func(original_mask, dilation_type, dilation_size).squeeze()
663
+
664
+ mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
665
+
666
+ masked_image = original_image * (1 - (random_mask[:,:,None]>0))
667
+ masked_image = masked_image.astype(original_image.dtype)
668
+ masked_image = Image.fromarray(masked_image)
669
+
670
+ return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
671
+
672
+
673
+ def process_erosion_mask(input_image,
674
+ original_image,
675
+ original_mask,
676
+ resize_default,
677
+ aspect_ratio_name,
678
+ dilation_size=20):
679
+ alpha_mask = input_image["layers"][0].split()[3]
680
+ input_mask = np.asarray(alpha_mask)
681
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
682
+ if output_w == "" or output_h == "":
683
+ output_h, output_w = original_image.shape[:2]
684
+ if resize_default:
685
+ short_side = min(output_w, output_h)
686
+ scale_ratio = 640 / short_side
687
+ output_w = int(output_w * scale_ratio)
688
+ output_h = int(output_h * scale_ratio)
689
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
690
+ original_image = np.array(original_image)
691
+ if input_mask is not None:
692
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
693
+ input_mask = np.array(input_mask)
694
+ if original_mask is not None:
695
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
696
+ original_mask = np.array(original_mask)
697
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
698
+ else:
699
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
700
+ pass
701
+ else:
702
+ if resize_default:
703
+ short_side = min(output_w, output_h)
704
+ scale_ratio = 640 / short_side
705
+ output_w = int(output_w * scale_ratio)
706
+ output_h = int(output_h * scale_ratio)
707
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
708
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
709
+ original_image = np.array(original_image)
710
+ if input_mask is not None:
711
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
712
+ input_mask = np.array(input_mask)
713
+ if original_mask is not None:
714
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
715
+ original_mask = np.array(original_mask)
716
+
717
+ if input_mask.max() == 0:
718
+ original_mask = original_mask
719
+ else:
720
+ original_mask = input_mask
721
+
722
+ if original_mask is None:
723
+ raise gr.Error('Please generate mask first')
724
+
725
+ if original_mask.ndim == 2:
726
+ original_mask = original_mask[:,:,None]
727
+
728
+ dilation_type = np.random.choice(['square_erosion'])
729
+ random_mask = random_mask_func(original_mask, dilation_type, dilation_size).squeeze()
730
+
731
+ mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
732
+
733
+ masked_image = original_image * (1 - (random_mask[:,:,None]>0))
734
+ masked_image = masked_image.astype(original_image.dtype)
735
+ masked_image = Image.fromarray(masked_image)
736
+
737
+
738
+ return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
739
+
740
+
741
+ def move_mask_left(input_image,
742
+ original_image,
743
+ original_mask,
744
+ moving_pixels,
745
+ resize_default,
746
+ aspect_ratio_name):
747
+
748
+ alpha_mask = input_image["layers"][0].split()[3]
749
+ input_mask = np.asarray(alpha_mask)
750
+
751
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
752
+ if output_w == "" or output_h == "":
753
+ output_h, output_w = original_image.shape[:2]
754
+ if resize_default:
755
+ short_side = min(output_w, output_h)
756
+ scale_ratio = 640 / short_side
757
+ output_w = int(output_w * scale_ratio)
758
+ output_h = int(output_h * scale_ratio)
759
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
760
+ original_image = np.array(original_image)
761
+ if input_mask is not None:
762
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
763
+ input_mask = np.array(input_mask)
764
+ if original_mask is not None:
765
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
766
+ original_mask = np.array(original_mask)
767
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
768
+ else:
769
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
770
+ pass
771
+ else:
772
+ if resize_default:
773
+ short_side = min(output_w, output_h)
774
+ scale_ratio = 640 / short_side
775
+ output_w = int(output_w * scale_ratio)
776
+ output_h = int(output_h * scale_ratio)
777
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
778
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
779
+ original_image = np.array(original_image)
780
+ if input_mask is not None:
781
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
782
+ input_mask = np.array(input_mask)
783
+ if original_mask is not None:
784
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
785
+ original_mask = np.array(original_mask)
786
+
787
+ if input_mask.max() == 0:
788
+ original_mask = original_mask
789
+ else:
790
+ original_mask = input_mask
791
+
792
+ if original_mask is None:
793
+ raise gr.Error('Please generate mask first')
794
+
795
+ if original_mask.ndim == 2:
796
+ original_mask = original_mask[:,:,None]
797
+
798
+ moved_mask = move_mask_func(original_mask, 'left', int(moving_pixels)).squeeze()
799
+ mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
800
+
801
+ masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
802
+ masked_image = masked_image.astype(original_image.dtype)
803
+ masked_image = Image.fromarray(masked_image)
804
+
805
+ if moved_mask.max() <= 1:
806
+ moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
807
+ original_mask = moved_mask
808
+ return [masked_image], [mask_image], original_mask.astype(np.uint8)
809
+
810
+
811
+ def move_mask_right(input_image,
812
+ original_image,
813
+ original_mask,
814
+ moving_pixels,
815
+ resize_default,
816
+ aspect_ratio_name):
817
+ alpha_mask = input_image["layers"][0].split()[3]
818
+ input_mask = np.asarray(alpha_mask)
819
+
820
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
821
+ if output_w == "" or output_h == "":
822
+ output_h, output_w = original_image.shape[:2]
823
+ if resize_default:
824
+ short_side = min(output_w, output_h)
825
+ scale_ratio = 640 / short_side
826
+ output_w = int(output_w * scale_ratio)
827
+ output_h = int(output_h * scale_ratio)
828
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
829
+ original_image = np.array(original_image)
830
+ if input_mask is not None:
831
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
832
+ input_mask = np.array(input_mask)
833
+ if original_mask is not None:
834
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
835
+ original_mask = np.array(original_mask)
836
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
837
+ else:
838
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
839
+ pass
840
+ else:
841
+ if resize_default:
842
+ short_side = min(output_w, output_h)
843
+ scale_ratio = 640 / short_side
844
+ output_w = int(output_w * scale_ratio)
845
+ output_h = int(output_h * scale_ratio)
846
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
847
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
848
+ original_image = np.array(original_image)
849
+ if input_mask is not None:
850
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
851
+ input_mask = np.array(input_mask)
852
+ if original_mask is not None:
853
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
854
+ original_mask = np.array(original_mask)
855
+
856
+ if input_mask.max() == 0:
857
+ original_mask = original_mask
858
+ else:
859
+ original_mask = input_mask
860
+
861
+ if original_mask is None:
862
+ raise gr.Error('Please generate mask first')
863
+
864
+ if original_mask.ndim == 2:
865
+ original_mask = original_mask[:,:,None]
866
+
867
+ moved_mask = move_mask_func(original_mask, 'right', int(moving_pixels)).squeeze()
868
+
869
+ mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
870
+
871
+ masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
872
+ masked_image = masked_image.astype(original_image.dtype)
873
+ masked_image = Image.fromarray(masked_image)
874
+
875
+
876
+ if moved_mask.max() <= 1:
877
+ moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
878
+ original_mask = moved_mask
879
+
880
+ return [masked_image], [mask_image], original_mask.astype(np.uint8)
881
+
882
+
883
+ def move_mask_up(input_image,
884
+ original_image,
885
+ original_mask,
886
+ moving_pixels,
887
+ resize_default,
888
+ aspect_ratio_name):
889
+ alpha_mask = input_image["layers"][0].split()[3]
890
+ input_mask = np.asarray(alpha_mask)
891
+
892
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
893
+ if output_w == "" or output_h == "":
894
+ output_h, output_w = original_image.shape[:2]
895
+ if resize_default:
896
+ short_side = min(output_w, output_h)
897
+ scale_ratio = 640 / short_side
898
+ output_w = int(output_w * scale_ratio)
899
+ output_h = int(output_h * scale_ratio)
900
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
901
+ original_image = np.array(original_image)
902
+ if input_mask is not None:
903
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
904
+ input_mask = np.array(input_mask)
905
+ if original_mask is not None:
906
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
907
+ original_mask = np.array(original_mask)
908
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
909
+ else:
910
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
911
+ pass
912
+ else:
913
+ if resize_default:
914
+ short_side = min(output_w, output_h)
915
+ scale_ratio = 640 / short_side
916
+ output_w = int(output_w * scale_ratio)
917
+ output_h = int(output_h * scale_ratio)
918
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
919
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
920
+ original_image = np.array(original_image)
921
+ if input_mask is not None:
922
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
923
+ input_mask = np.array(input_mask)
924
+ if original_mask is not None:
925
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
926
+ original_mask = np.array(original_mask)
927
+
928
+ if input_mask.max() == 0:
929
+ original_mask = original_mask
930
+ else:
931
+ original_mask = input_mask
932
+
933
+ if original_mask is None:
934
+ raise gr.Error('Please generate mask first')
935
+
936
+ if original_mask.ndim == 2:
937
+ original_mask = original_mask[:,:,None]
938
+
939
+ moved_mask = move_mask_func(original_mask, 'up', int(moving_pixels)).squeeze()
940
+ mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
941
+
942
+ masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
943
+ masked_image = masked_image.astype(original_image.dtype)
944
+ masked_image = Image.fromarray(masked_image)
945
+
946
+ if moved_mask.max() <= 1:
947
+ moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
948
+ original_mask = moved_mask
949
+
950
+ return [masked_image], [mask_image], original_mask.astype(np.uint8)
951
+
952
+
953
+ def move_mask_down(input_image,
954
+ original_image,
955
+ original_mask,
956
+ moving_pixels,
957
+ resize_default,
958
+ aspect_ratio_name):
959
+ alpha_mask = input_image["layers"][0].split()[3]
960
+ input_mask = np.asarray(alpha_mask)
961
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
962
+ if output_w == "" or output_h == "":
963
+ output_h, output_w = original_image.shape[:2]
964
+ if resize_default:
965
+ short_side = min(output_w, output_h)
966
+ scale_ratio = 640 / short_side
967
+ output_w = int(output_w * scale_ratio)
968
+ output_h = int(output_h * scale_ratio)
969
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
970
+ original_image = np.array(original_image)
971
+ if input_mask is not None:
972
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
973
+ input_mask = np.array(input_mask)
974
+ if original_mask is not None:
975
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
976
+ original_mask = np.array(original_mask)
977
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
978
+ else:
979
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
980
+ pass
981
+ else:
982
+ if resize_default:
983
+ short_side = min(output_w, output_h)
984
+ scale_ratio = 640 / short_side
985
+ output_w = int(output_w * scale_ratio)
986
+ output_h = int(output_h * scale_ratio)
987
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
988
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
989
+ original_image = np.array(original_image)
990
+ if input_mask is not None:
991
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
992
+ input_mask = np.array(input_mask)
993
+ if original_mask is not None:
994
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
995
+ original_mask = np.array(original_mask)
996
+
997
+ if input_mask.max() == 0:
998
+ original_mask = original_mask
999
+ else:
1000
+ original_mask = input_mask
1001
+
1002
+ if original_mask is None:
1003
+ raise gr.Error('Please generate mask first')
1004
+
1005
+ if original_mask.ndim == 2:
1006
+ original_mask = original_mask[:,:,None]
1007
+
1008
+ moved_mask = move_mask_func(original_mask, 'down', int(moving_pixels)).squeeze()
1009
+ mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
1010
+
1011
+ masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
1012
+ masked_image = masked_image.astype(original_image.dtype)
1013
+ masked_image = Image.fromarray(masked_image)
1014
+
1015
+ if moved_mask.max() <= 1:
1016
+ moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
1017
+ original_mask = moved_mask
1018
+
1019
+ return [masked_image], [mask_image], original_mask.astype(np.uint8)
1020
+
1021
+
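The four move_mask_* handlers delegate the actual pixel shift to `move_mask_func`, which is defined elsewhere in the app. A minimal sketch of what such a shift could look like, assuming the mask is an H×W×1 uint8 array and that pixels pushed past the border are dropped rather than wrapped (the real implementation may differ):

```python
import numpy as np

def move_mask_sketch(mask: np.ndarray, direction: str, pixels: int) -> np.ndarray:
    # Shift a binary mask by `pixels` in one of four directions, zero-filling
    # the vacated region instead of wrapping around.
    h, w = mask.shape[:2]
    moved = np.zeros_like(mask)
    if direction == "right":
        moved[:, pixels:] = mask[:, :w - pixels]
    elif direction == "left":
        moved[:, :w - pixels] = mask[:, pixels:]
    elif direction == "down":
        moved[pixels:, :] = mask[:h - pixels, :]
    elif direction == "up":
        moved[:h - pixels, :] = mask[pixels:, :]
    return moved
```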
1022
+ def invert_mask(input_image,
1023
+ original_image,
1024
+ original_mask,
1025
+ ):
1026
+ alpha_mask = input_image["layers"][0].split()[3]
1027
+ input_mask = np.asarray(alpha_mask)
1028
+ if input_mask.max() == 0:
1029
+ original_mask = 1 - (original_mask>0).astype(np.uint8)
1030
+ else:
1031
+ original_mask = 1 - (input_mask>0).astype(np.uint8)
1032
+
1033
+ if original_mask is None:
1034
+ raise gr.Error('Please generate mask first')
1035
+
1036
+ original_mask = original_mask.squeeze()
1037
+ mask_image = Image.fromarray(original_mask*255).convert("RGB")
1038
+
1039
+ if original_mask.ndim == 2:
1040
+ original_mask = original_mask[:,:,None]
1041
+
1042
+ if original_mask.max() <= 1:
1043
+ original_mask = (original_mask * 255).astype(np.uint8)
1044
+
1045
+ masked_image = original_image * (1 - (original_mask>0))
1046
+ masked_image = masked_image.astype(original_image.dtype)
1047
+ masked_image = Image.fromarray(masked_image)
1048
+
1049
+ return [masked_image], [mask_image], original_mask, True
1050
+
1051
+
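invert_mask flips the selection with `1 - (mask > 0)`; a tiny numpy illustration of that identity on hypothetical values, including the rescale back to the 0/255 convention used elsewhere in the app:

```python
import numpy as np

mask = np.array([[0, 255], [128, 0]], dtype=np.uint8)
inverted = 1 - (mask > 0).astype(np.uint8)  # 1 where the mask was empty, 0 where it was set
print(inverted)        # [[1 0]
                       #  [0 1]]
print(inverted * 255)  # [[255 0] [0 255]] -- back to the 0/255 range
```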
1052
+ def init_img(base,
1053
+ init_type,
1054
+ prompt,
1055
+ aspect_ratio,
1056
+ example_change_times
1057
+ ):
1058
+ image_pil = base["background"].convert("RGB")
1059
+ original_image = np.array(image_pil)
1060
+ if max(original_image.shape[0], original_image.shape[1]) * 1.0 / min(original_image.shape[0], original_image.shape[1])>2.0:
1061
+ raise gr.Error('image aspect ratio cannot be larger than 2.0')
1062
+ if init_type in MASK_IMAGE_PATH.keys() and example_change_times < 2:
1063
+ mask_gallery = [Image.open(MASK_IMAGE_PATH[init_type]).convert("L")]
1064
+ masked_gallery = [Image.open(MASKED_IMAGE_PATH[init_type]).convert("RGB")]
1065
+ result_gallery = [Image.open(OUTPUT_IMAGE_PATH[init_type]).convert("RGB")]
1066
+ width, height = image_pil.size
1067
+ image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
1068
+ height_new, width_new = image_processor.get_default_height_width(image_pil, height, width)
1069
+ image_pil = image_pil.resize((width_new, height_new))
1070
+ mask_gallery[0] = mask_gallery[0].resize((width_new, height_new))
1071
+ masked_gallery[0] = masked_gallery[0].resize((width_new, height_new))
1072
+ result_gallery[0] = result_gallery[0].resize((width_new, height_new))
1073
+ original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1
1074
+ return base, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, "", "", "Custom resolution", False, False, example_change_times
1075
+ else:
1076
+ if aspect_ratio not in ASPECT_RATIO_LABELS:
1077
+ aspect_ratio = "Custom resolution"
1078
+ return base, original_image, None, "", None, None, None, "", "", aspect_ratio, True, False, 0
1079
+
1080
+
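init_img snaps the example images to dimensions the VAE can process via `VaeImageProcessor.get_default_height_width`. The effect is essentially rounding each side down to a multiple of the pipeline's `vae_scale_factor` (typically 8); a small sketch of the equivalent arithmetic under that assumption:

```python
def snap_to_vae_grid(width: int, height: int, vae_scale_factor: int = 8) -> tuple:
    # Round both sides down to a multiple of the VAE scale factor, mirroring what
    # VaeImageProcessor.get_default_height_width does for PIL inputs.
    return ((width // vae_scale_factor) * vae_scale_factor,
            (height // vae_scale_factor) * vae_scale_factor)

print(snap_to_vae_grid(1023, 771))  # (1016, 768)
```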
1081
+ def reset_func(input_image,
1082
+ original_image,
1083
+ original_mask,
1084
+ prompt,
1085
+ target_prompt,
1086
+ ):
1087
+ input_image = None
1088
+ original_image = None
1089
+ original_mask = None
1090
+ prompt = ''
1091
+ mask_gallery = []
1092
+ masked_gallery = []
1093
+ result_gallery = []
1094
+ target_prompt = ''
1095
+ if torch.cuda.is_available():
1096
+ torch.cuda.empty_cache()
1097
+ return input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, True, False
1098
+
1099
+
1100
+ def update_example(example_type,
1101
+ prompt,
1102
+ example_change_times):
1103
+ input_image = INPUT_IMAGE_PATH[example_type]
1104
+ image_pil = Image.open(input_image).convert("RGB")
1105
+ mask_gallery = [Image.open(MASK_IMAGE_PATH[example_type]).convert("L")]
1106
+ masked_gallery = [Image.open(MASKED_IMAGE_PATH[example_type]).convert("RGB")]
1107
+ result_gallery = [Image.open(OUTPUT_IMAGE_PATH[example_type]).convert("RGB")]
1108
+ width, height = image_pil.size
1109
+ image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
1110
+ height_new, width_new = image_processor.get_default_height_width(image_pil, height, width)
1111
+ image_pil = image_pil.resize((width_new, height_new))
1112
+ mask_gallery[0] = mask_gallery[0].resize((width_new, height_new))
1113
+ masked_gallery[0] = masked_gallery[0].resize((width_new, height_new))
1114
+ result_gallery[0] = result_gallery[0].resize((width_new, height_new))
1115
+
1116
+ original_image = np.array(image_pil)
1117
+ original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1
1118
+ aspect_ratio = "Custom resolution"
1119
+ example_change_times += 1
1120
+ return input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, "", False, example_change_times
1121
+
1122
+
1123
+ def generate_target_prompt(input_image,
1124
+ original_image,
1125
+ prompt):
1126
+ # load example image
1127
+ if isinstance(original_image, str):
1128
+ original_image = input_image
1129
+
1130
+ prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(
1131
+ vlm_processor,
1132
+ vlm_model,
1133
+ original_image,
1134
+ prompt,
1135
+ device)
1136
+ return prompt_after_apply_instruction
1137
+
1138
+
1139
+
1140
+
1141
+ def process_mask(input_image,
1142
+ original_image,
1143
+ prompt,
1144
+ resize_default,
1145
+ aspect_ratio_name):
1146
+ if original_image is None:
1147
+ raise gr.Error('Please upload the input image')
1148
+ if prompt is None:
1149
+ raise gr.Error("Please input your instructions, e.g., remove the xxx")
1150
+
1151
+ ## load mask
1152
+ alpha_mask = input_image["layers"][0].split()[3]
1153
+ input_mask = np.array(alpha_mask)
1154
+
1155
+ # load example image
1156
+ if isinstance(original_image, str):
1157
+ original_image = input_image["background"]
1158
+
1159
+ if input_mask.max() == 0:
1160
+ category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device)
1161
+
1162
+ object_wait_for_edit = vlm_response_object_wait_for_edit(vlm_processor,
1163
+ vlm_model,
1164
+ original_image,
1165
+ category,
1166
+ prompt,
1167
+ device)
1168
+ # original mask: h,w,1 [0, 255]
1169
+ original_mask = vlm_response_mask(
1170
+ vlm_processor,
1171
+ vlm_model,
1172
+ category,
1173
+ original_image,
1174
+ prompt,
1175
+ object_wait_for_edit,
1176
+ sam,
1177
+ sam_predictor,
1178
+ sam_automask_generator,
1179
+ groundingdino_model,
1180
+ device).astype(np.uint8)
1181
+ else:
1182
+ original_mask = input_mask.astype(np.uint8)
1183
+ category = None
1184
+
1185
+ ## resize mask if needed
1186
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
1187
+ if output_w == "" or output_h == "":
1188
+ output_h, output_w = original_image.shape[:2]
1189
+ if resize_default:
1190
+ short_side = min(output_w, output_h)
1191
+ scale_ratio = 640 / short_side
1192
+ output_w = int(output_w * scale_ratio)
1193
+ output_h = int(output_h * scale_ratio)
1194
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1195
+ original_image = np.array(original_image)
1196
+ if input_mask is not None:
1197
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1198
+ input_mask = np.array(input_mask)
1199
+ if original_mask is not None:
1200
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1201
+ original_mask = np.array(original_mask)
1202
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1203
+ else:
1204
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1205
+ pass
1206
+ else:
1207
+ if resize_default:
1208
+ short_side = min(output_w, output_h)
1209
+ scale_ratio = 640 / short_side
1210
+ output_w = int(output_w * scale_ratio)
1211
+ output_h = int(output_h * scale_ratio)
1212
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1213
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1214
+ original_image = np.array(original_image)
1215
+ if input_mask is not None:
1216
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1217
+ input_mask = np.array(input_mask)
1218
+ if original_mask is not None:
1219
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1220
+ original_mask = np.array(original_mask)
1221
+
1222
+
1223
+ if original_mask.ndim == 2:
1224
+ original_mask = original_mask[:,:,None]
1225
+
1226
+ mask_image = Image.fromarray(original_mask.squeeze().astype(np.uint8)).convert("RGB")
1227
+
1228
+ masked_image = original_image * (1 - (original_mask>0))
1229
+ masked_image = masked_image.astype(np.uint8)
1230
+ masked_image = Image.fromarray(masked_image)
1231
+
1232
+ return [masked_image], [mask_image], original_mask.astype(np.uint8), category
1233
+
1234
+
1235
+
1236
+ def process(input_image,
1237
+ original_image,
1238
+ original_mask,
1239
+ prompt,
1240
+ negative_prompt,
1241
+ control_strength,
1242
+ seed,
1243
+ randomize_seed,
1244
+ guidance_scale,
1245
+ num_inference_steps,
1246
+ num_samples,
1247
+ blending,
1248
+ category,
1249
+ target_prompt,
1250
+ resize_default,
1251
+ aspect_ratio_name,
1252
+ invert_mask_state):
1253
+ if original_image is None:
1254
+ if input_image is None:
1255
+ raise gr.Error('Please upload the input image')
1256
+ else:
1257
+ image_pil = input_image["background"].convert("RGB")
1258
+ original_image = np.array(image_pil)
1259
+ if prompt is None or prompt == "":
1260
+ if target_prompt is None or target_prompt == "":
1261
+ raise gr.Error("Please input your instructions, e.g., remove the xxx")
1262
+
1263
+ alpha_mask = input_image["layers"][0].split()[3]
1264
+ input_mask = np.asarray(alpha_mask)
1265
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
1266
+ if output_w == "" or output_h == "":
1267
+ output_h, output_w = original_image.shape[:2]
1268
+
1269
+ if resize_default:
1270
+ short_side = min(output_w, output_h)
1271
+ scale_ratio = 640 / short_side
1272
+ output_w = int(output_w * scale_ratio)
1273
+ output_h = int(output_h * scale_ratio)
1274
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1275
+ original_image = np.array(original_image)
1276
+ if input_mask is not None:
1277
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1278
+ input_mask = np.array(input_mask)
1279
+ if original_mask is not None:
1280
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1281
+ original_mask = np.array(original_mask)
1282
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1283
+ else:
1284
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1285
+ pass
1286
+ else:
1287
+ if resize_default:
1288
+ short_side = min(output_w, output_h)
1289
+ scale_ratio = 640 / short_side
1290
+ output_w = int(output_w * scale_ratio)
1291
+ output_h = int(output_h * scale_ratio)
1292
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1293
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1294
+ original_image = np.array(original_image)
1295
+ if input_mask is not None:
1296
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1297
+ input_mask = np.array(input_mask)
1298
+ if original_mask is not None:
1299
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1300
+ original_mask = np.array(original_mask)
1301
+
1302
+ if invert_mask_state:
1303
+ original_mask = original_mask
1304
+ else:
1305
+ if input_mask.max() == 0:
1306
+ original_mask = original_mask
1307
+ else:
1308
+ original_mask = input_mask
1309
+
1310
+
1311
+ # inpainting directly if target_prompt is not None
1312
+ if category is not None:
1313
+ pass
1314
+ elif target_prompt is not None and len(target_prompt) >= 1 and original_mask is not None:
1315
+ pass
1316
+ else:
1317
+ try:
1318
+ category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device)
1319
+ except Exception as e:
1320
+ raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
1321
+
1322
+
1323
+ if original_mask is not None:
1324
+ original_mask = np.clip(original_mask, 0, 255).astype(np.uint8)
1325
+ else:
1326
+ try:
1327
+ object_wait_for_edit = vlm_response_object_wait_for_edit(
1328
+ vlm_processor,
1329
+ vlm_model,
1330
+ original_image,
1331
+ category,
1332
+ prompt,
1333
+ device)
1334
+
1335
+ original_mask = vlm_response_mask(vlm_processor,
1336
+ vlm_model,
1337
+ category,
1338
+ original_image,
1339
+ prompt,
1340
+ object_wait_for_edit,
1341
+ sam,
1342
+ sam_predictor,
1343
+ sam_automask_generator,
1344
+ groundingdino_model,
1345
+ device).astype(np.uint8)
1346
+ except Exception as e:
1347
+ raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
1348
+
1349
+ if original_mask.ndim == 2:
1350
+ original_mask = original_mask[:,:,None]
1351
+
1352
+
1353
+ if target_prompt is not None and len(target_prompt) >= 1:
1354
+ prompt_after_apply_instruction = target_prompt
1355
+
1356
+ else:
1357
+ try:
1358
+ prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(
1359
+ vlm_processor,
1360
+ vlm_model,
1361
+ original_image,
1362
+ prompt,
1363
+ device)
1364
+ except Exception as e:
1365
+ raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
1366
+
1367
+ generator = torch.Generator(device).manual_seed(random.randint(0, 2147483647) if randomize_seed else seed)
1368
+
1369
+
1370
+ with torch.autocast(device):
1371
+ image, mask_image, mask_np, init_image_np = BrushEdit_Pipeline(pipe,
1372
+ prompt_after_apply_instruction,
1373
+ original_mask,
1374
+ original_image,
1375
+ generator,
1376
+ num_inference_steps,
1377
+ guidance_scale,
1378
+ control_strength,
1379
+ negative_prompt,
1380
+ num_samples,
1381
+ blending)
1382
+ original_image = np.array(init_image_np)
1383
+ masked_image = original_image * (1 - (mask_np>0))
1384
+ masked_image = masked_image.astype(np.uint8)
1385
+ masked_image = Image.fromarray(masked_image)
1386
+ # Save the images (optional)
1387
+ # import uuid
1388
+ # uuid = str(uuid.uuid4())
1389
+ # image[0].save(f"outputs/image_edit_{uuid}_0.png")
1390
+ # image[1].save(f"outputs/image_edit_{uuid}_1.png")
1391
+ # image[2].save(f"outputs/image_edit_{uuid}_2.png")
1392
+ # image[3].save(f"outputs/image_edit_{uuid}_3.png")
1393
+ # mask_image.save(f"outputs/mask_{uuid}.png")
1394
+ # masked_image.save(f"outputs/masked_image_{uuid}.png")
1395
+ # gr.Info(f"Target Prompt: {prompt_after_apply_instruction}", duration=20)
1396
+ return image, [mask_image], [masked_image], prompt, '', False
1397
+
1398
+
1399
+ # Newly added event handler functions
1400
+ def generate_blip_description(input_image):
1401
+ if input_image is None:
1402
+ return "", "Input image cannot be None"
1403
+ try:
1404
+ image_pil = input_image["background"].convert("RGB")
1405
+ except KeyError:
1406
+ return "", "Input image missing 'background' key"
1407
+ except AttributeError as e:
1408
+ return "", f"Invalid image object: {str(e)}"
1409
+ try:
1410
+ description = generate_caption(blip_processor, blip_model, image_pil, device)
1411
+ return description, description # update both the state and the displayed textbox
1412
+ except Exception as e:
1413
+ return "", f"Caption generation failed: {str(e)}"
1414
+
1415
+ from app.utils.utils import generate_caption
1416
+ blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
1417
+ clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
1418
+ blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large", torch_dtype=torch.float16).to(device)
1419
+ clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32",torch_dtype=torch.float16).to(device)
1420
+
1421
+
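generate_blip_description relies on `generate_caption` from app.utils.utils, loaded together with the BLIP and CLIP weights above. A minimal sketch of what such a captioning helper can look like with the standard `transformers` BLIP API (the project's own helper may add prompts or post-processing):

```python
import torch
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration

def generate_caption_sketch(processor: BlipProcessor,
                            model: BlipForConditionalGeneration,
                            image: Image.Image,
                            device: str = "cuda") -> str:
    # Encode the image, decode a short caption, and strip special tokens.
    inputs = processor(images=image, return_tensors="pt").to(device, torch.float16)
    out = model.generate(**inputs, max_new_tokens=50)
    return processor.decode(out[0], skip_special_tokens=True)
```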
1422
+ def submit_GPT4o_KEY(GPT4o_KEY):
1423
+ global vlm_model, vlm_processor
1424
+ if vlm_model is not None:
1425
+ del vlm_model
1426
+ torch.cuda.empty_cache()
1427
+ try:
1428
+ vlm_model = OpenAI(api_key=GPT4o_KEY, base_url="https://api.deepseek.com")
1429
+ vlm_processor = ""
1430
+ response = vlm_model.chat.completions.create(
1431
+ model="deepseek-chat",
1432
+ messages=[
1433
+ {"role": "system", "content": "You are a helpful assistant."},
1434
+ {"role": "user", "content": "Hello."}
1435
+ ]
1436
+ )
1437
+ response_str = response.choices[0].message.content
1438
+
1439
+ return "Success. " + response_str, "GPT4-o (Highly Recommended)"
1440
+ except Exception as e:
1441
+ return "Invalid GPT4o API Key", "GPT4-o (Highly Recommended)"
1442
+
1443
+
1444
+ def verify_deepseek_api():
1445
+ try:
1446
+ response = llm_model.chat.completions.create(
1447
+ model="deepseek-chat",
1448
+ messages=[
1449
+ {"role": "system", "content": "You are a helpful assistant."},
1450
+ {"role": "user", "content": "Hello."}
1451
+ ]
1452
+ )
1453
+ response_str = response.choices[0].message.content
1454
+
1455
+ return True, "Success. " + response_str
1456
+
1457
+ except Exception as e:
1458
+ return False, "Invalid DeepSeek API Key"
1459
+
1460
+
1461
+ def llm_enhanced_prompt_after_apply_instruction(image_caption, editing_prompt):
1462
+ try:
1463
+ messages = create_apply_editing_messages_deepseek(image_caption, editing_prompt)
1464
+ response = llm_model.chat.completions.create(
1465
+ model="deepseek-chat",
1466
+ messages=messages
1467
+ )
1468
+ response_str = response.choices[0].message.content
1469
+ return response_str
1470
+ except Exception as e:
1471
+ raise gr.Error(f"Error while integrating the instruction: {str(e)}; check the console log for details")
1472
+
1473
+
1474
+ def llm_decomposed_prompt_after_apply_instruction(integrated_query):
1475
+ try:
1476
+ messages = create_decomposed_query_messages_deepseek(integrated_query)
1477
+ response = llm_model.chat.completions.create(
1478
+ model="deepseek-chat",
1479
+ messages=messages
1480
+ )
1481
+ response_str = response.choices[0].message.content
1482
+ return response_str
1483
+ except Exception as e:
1484
+ raise gr.Error(f"Error while decomposing the instruction: {str(e)}; check the console log for details")
1485
+
1486
+
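The two LLM helpers above hand prompt construction off to `create_apply_editing_messages_deepseek` and `create_decomposed_query_messages_deepseek`, imported elsewhere in the app. A hypothetical sketch of what such a builder could return for the chat-completions call; the system prompt wording here is illustrative, not the project's actual template:

```python
def create_apply_editing_messages_sketch(image_caption: str, editing_prompt: str) -> list:
    # Build an OpenAI-style message list: a fixed system instruction plus one user
    # turn carrying the BLIP caption and the user's editing request.
    return [
        {"role": "system",
         "content": "Rewrite the image caption so that it already reflects the requested edit."},
        {"role": "user",
         "content": f"Caption: {image_caption}\nEdit request: {editing_prompt}"},
    ]
```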
1487
+ def enhance_description(blip_description, prompt):
1488
+ try:
1489
+ if not prompt or not blip_description:
1490
+ print("Empty prompt or blip_description detected")
1491
+ return "", ""
1492
+
1493
+ print(f"Enhancing with prompt: {prompt}")
1494
+ enhanced_description = llm_enhanced_prompt_after_apply_instruction(blip_description, prompt)
1495
+ return enhanced_description, enhanced_description
1496
+
1497
+ except Exception as e:
1498
+ print(f"Enhancement failed: {str(e)}")
1499
+ return "Error occurred", "Error occurred"
1500
+
1501
+
1502
+ def decompose_description(enhanced_description):
1503
+ try:
1504
+ if not enhanced_description:
1505
+ print("Empty enhanced_description detected")
1506
+ return "", ""
1507
+
1508
+ print(f"Decomposing the enhanced description: {enhanced_description}")
1509
+ decomposed_description = llm_decomposed_prompt_after_apply_instruction(enhanced_description)
1510
+ return decomposed_description, decomposed_description
1511
+
1512
+ except Exception as e:
1513
+ print(f"Decomposition failed: {str(e)}")
1514
+ return "Error occurred", "Error occurred"
1515
+
1516
+
1517
+ @torch.no_grad()
1518
+ def mix_and_search(enhanced_text: str, gallery_images: list):
1519
+ # Get the most recently generated image tuple
1520
+ latest_item = gallery_images[-1] if gallery_images else None
1521
+
1522
+ # Initialize the feature list
1523
+ features = []
1524
+
1525
+ # Image feature extraction
1526
+ if latest_item and isinstance(latest_item, tuple):
1527
+ try:
1528
+ image_path = latest_item[0]
1529
+ pil_image = Image.open(image_path).convert("RGB")
1530
+
1531
+ # Process the image with CLIPProcessor
1532
+ image_inputs = clip_processor(
1533
+ images=pil_image,
1534
+ return_tensors="pt"
1535
+ ).to(device)
1536
+
1537
+ image_features = clip_model.get_image_features(**image_inputs)
1538
+ features.append(F.normalize(image_features, dim=-1))
1539
+ except Exception as e:
1540
+ print(f"Image processing failed: {str(e)}")
1541
+
1542
+ # Text feature extraction
1543
+ if enhanced_text.strip():
1544
+ text_inputs = clip_processor(
1545
+ text=enhanced_text,
1546
+ return_tensors="pt",
1547
+ padding=True,
1548
+ truncation=True
1549
+ ).to(device)
1550
+
1551
+ text_features = clip_model.get_text_features(**text_inputs)
1552
+ features.append(F.normalize(text_features, dim=-1))
1553
+
1554
+ if not features:
1555
+ return "## Error: please finish the image edit and generate a description first", []
1556
+
1557
+ # Feature fusion and retrieval
1558
+ mixed = sum(features) / len(features)
1559
+ mixed = F.normalize(mixed, dim=-1)
1560
+
1561
+ # Load the Faiss index and the image-path mapping
1562
+ index_path = "/home/zt/data/open-images/train/knn.index"
1563
+ input_data_dir = Path("/home/zt/data/open-images/train/embedding_folder/metadata")
1564
+ base_image_dir = Path("/home/zt/data/open-images/train/")
1565
+
1566
+ # Sort the parquet files by the number in their filenames and read them directly
1567
+ parquet_files = sorted(
1568
+ input_data_dir.glob('*.parquet'),
1569
+ key=lambda x: int(x.stem.split("_")[-1])
1570
+ )
1571
+
1572
+ # Merge all parquet data
1574
+ dfs = [pd.read_parquet(file) for file in parquet_files] # read each file inline
1574
+ df = pd.concat(dfs, ignore_index=True)
1575
+ image_paths = df["image_path"].tolist()
1576
+
1577
+ # Load the Faiss index
1578
+ index = faiss.read_index(index_path)
1579
+ assert mixed.shape[1] == index.d, "Feature dimension does not match the index"
1580
+
1581
+ # Run the search
1582
+ mixed = mixed.cpu().detach().numpy().astype('float32')
1583
+ distances, indices = index.search(mixed, 5)
1584
+
1585
+ # Fetch and validate the image paths
1586
+ retrieved_images = []
1587
+ for idx in indices[0]:
1588
+ if 0 <= idx < len(image_paths):
1589
+ img_path = base_image_dir / image_paths[idx]
1590
+ try:
1591
+ if img_path.exists():
1592
+ retrieved_images.append(Image.open(img_path).convert("RGB"))
1593
+ else:
1594
+ print(f"Warning: missing file {img_path}")
1595
+ except Exception as e:
1596
+ print(f"Failed to load image: {str(e)}")
1597
+
1598
+ return "## Retrieved the following similar images:", retrieved_images if retrieved_images else ("## No matching images found", [])
1599
+
1600
+
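mix_and_search averages the L2-normalised CLIP image and text embeddings and looks the result up in a prebuilt Faiss kNN index. A condensed sketch of just the fusion-and-search step, assuming the CLIP features and a `faiss` index are already in memory and that the index was built from normalised embeddings (the top-k value is a placeholder):

```python
import faiss
import numpy as np
import torch
import torch.nn.functional as F

def fuse_and_search(image_features: torch.Tensor,
                    text_features: torch.Tensor,
                    index: "faiss.Index",
                    top_k: int = 5) -> np.ndarray:
    # Average the normalised modalities, renormalise, and query the kNN index.
    mixed = F.normalize(image_features, dim=-1) + F.normalize(text_features, dim=-1)
    mixed = F.normalize(mixed / 2, dim=-1)
    query = mixed.detach().cpu().numpy().astype("float32")
    _, indices = index.search(query, top_k)  # (1, top_k) row indices into the path table
    return indices[0]
```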
1601
+ block = gr.Blocks(
1602
+ theme=gr.themes.Soft(
1603
+ radius_size=gr.themes.sizes.radius_none,
1604
+ text_size=gr.themes.sizes.text_md
1605
+ )
1606
+ )
1607
+ with block as demo:
1608
+ with gr.Row():
1609
+ with gr.Column():
1610
+ gr.HTML(head)
1611
+
1612
+ gr.Markdown(descriptions)
1613
+
1614
+ with gr.Accordion(label="🧭 Instructions:", open=True, elem_id="accordion"):
1615
+ with gr.Row(equal_height=True):
1616
+ gr.Markdown(instructions)
1617
+
1618
+ original_image = gr.State(value=None)
1619
+ original_mask = gr.State(value=None)
1620
+ category = gr.State(value=None)
1621
+ status = gr.State(value=None)
1622
+ invert_mask_state = gr.State(value=False)
1623
+ example_change_times = gr.State(value=0)
1624
+ deepseek_verified = gr.State(value=False)
1625
+ blip_description = gr.State(value="")
1626
+ enhanced_description = gr.State(value="")
1627
+ decomposed_description = gr.State(value="")
1628
+
1629
+ with gr.Row():
1630
+ with gr.Column():
1631
+ with gr.Row():
1632
+ input_image = gr.ImageEditor(
1633
+ label="Reference Image",
1634
+ type="pil",
1635
+ brush=gr.Brush(colors=["#FFFFFF"], default_size = 30, color_mode="fixed"),
1636
+ layers = False,
1637
+ interactive=True,
1638
+ # height=1024,
1639
+ height=512,
1640
+ sources=["upload"],
1641
+ placeholder="🫧 Click here or the icon below to upload an image 🫧",
1642
+ )
1643
+
1644
+ prompt = gr.Textbox(label="Editing Instruction", placeholder="😜 Describe here how you want the reference image to be modified 😜", value="", lines=1)
1645
+ run_button = gr.Button("💫 Edit Image")
1646
+
1647
+ vlm_model_dropdown = gr.Dropdown(label="VLM Model", choices=VLM_MODEL_NAMES, value=DEFAULT_VLM_MODEL_NAME, interactive=True)
1648
+ with gr.Group():
1649
+ with gr.Row():
1650
+ # GPT4o_KEY = gr.Textbox(label="GPT4o API Key", placeholder="Please input your GPT4o API Key when use GPT4o VLM (highly recommended).", value="", lines=1)
1651
+ GPT4o_KEY = gr.Textbox(label="API Key", value="sk-d145b963a92649a88843caeb741e8bbc", lines=1)
1652
+ GPT4o_KEY_submit = gr.Button("🙈 Verify")
1653
+
1654
+
1655
+ aspect_ratio = gr.Dropdown(label="Output aspect ratio", choices=ASPECT_RATIO_LABELS, value=DEFAULT_ASPECT_RATIO)
1656
+ resize_default = gr.Checkbox(label="Short edge resize to 640px", value=True)
1657
+
1658
+ with gr.Row():
1659
+ mask_button = gr.Button("💎 Generate Mask")
1660
+ random_mask_button = gr.Button("Square/Circle Mask ")
1661
+
1662
+ # Added after the decompose button
1663
+ with gr.Group():
1664
+ with gr.Row():
1665
+ retrieve_button = gr.Button("🔍 Start Retrieval")
1666
+ with gr.Row():
1667
+ retrieve_output = gr.Markdown(elem_id="accordion")
1668
+ with gr.Row():
1669
+ retrieve_gallery = gr.Gallery(label="🎊 Retrieval Results", show_label=True, elem_id="gallery", preview=True, height=400) # newly added Gallery component
1670
+
1671
+ with gr.Row():
1672
+ generate_target_prompt_button = gr.Button("Generate Target Prompt")
1673
+
1674
+ target_prompt = gr.Text(
1675
+ label="Input Target Prompt",
1676
+ max_lines=5,
1677
+ placeholder="VLM-generated target prompt; you can generate it first and then modify it (optional)",
1678
+ value='',
1679
+ lines=2
1680
+ )
1681
+
1682
+ with gr.Accordion("Advanced Options", open=False, elem_id="accordion1"):
1683
+ base_model_dropdown = gr.Dropdown(label="Base model", choices=BASE_MODELS, value=DEFAULT_BASE_MODEL, interactive=True)
1684
+ negative_prompt = gr.Text(
1685
+ label="Negative Prompt",
1686
+ max_lines=5,
1687
+ placeholder="Please input your negative prompt",
1688
+ value='ugly, low quality',lines=1
1689
+ )
1690
+
1691
+ control_strength = gr.Slider(
1692
+ label="Control Strength: ", show_label=True, minimum=0, maximum=1.1, value=1, step=0.01
1693
+ )
1694
+ with gr.Group():
1695
+ seed = gr.Slider(
1696
+ label="Seed: ", minimum=0, maximum=2147483647, step=1, value=648464818
1697
+ )
1698
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
1699
+
1700
+ blending = gr.Checkbox(label="Blending mode", value=True)
1701
+
1702
+
1703
+ num_samples = gr.Slider(
1704
+ label="Num samples", minimum=0, maximum=4, step=1, value=4
1705
+ )
1706
+
1707
+ with gr.Group():
1708
+ with gr.Row():
1709
+ guidance_scale = gr.Slider(
1710
+ label="Guidance scale",
1711
+ minimum=1,
1712
+ maximum=12,
1713
+ step=0.1,
1714
+ value=7.5,
1715
+ )
1716
+ num_inference_steps = gr.Slider(
1717
+ label="Number of inference steps",
1718
+ minimum=1,
1719
+ maximum=50,
1720
+ step=1,
1721
+ value=50,
1722
+ )
1723
+
1724
+
1725
+ with gr.Group(visible=True):
1726
+ # Description generated by BLIP
1728
+ blip_output = gr.Textbox(label="Original Image Description", placeholder="💭 Base image description generated by BLIP 💭", interactive=True, lines=1)
1729
+ # DeepSeek API verification
1730
+ with gr.Row():
1731
+ deepseek_key = gr.Textbox(label="API Key", value="sk-d145b963a92649a88843caeb741e8bbc", lines=1)
1732
+ verify_deepseek = gr.Button("🙈 Verify")
1733
+ # Integrated description area
1734
+ with gr.Row():
1735
+ enhanced_output = gr.Textbox(label="Integrated Description", placeholder="💭 Enhanced description generated by DeepSeek 💭", interactive=True, lines=3)
1736
+ enhance_button = gr.Button("✨ Integrate")
1737
+ # Decomposed description area
1738
+ with gr.Row():
1739
+ decomposed_output = gr.Textbox(label="Decomposed Description", placeholder="💭 Decomposed description generated by DeepSeek 💭", interactive=True, lines=3)
1740
+ decompose_button = gr.Button("🔧 Decompose")
1740
+ with gr.Row():
1741
+ with gr.Tab(elem_classes="feedback", label="Masked Image"):
1742
+ masked_gallery = gr.Gallery(label='Masked Image', show_label=True, elem_id="gallery", preview=True, height=360)
1743
+ with gr.Tab(elem_classes="feedback", label="Mask"):
1744
+ mask_gallery = gr.Gallery(label='Mask', show_label=True, elem_id="gallery", preview=True, height=360)
1745
+
1746
+ invert_mask_button = gr.Button("Invert Mask")
1747
+ dilation_size = gr.Slider(
1748
+ label="Dilation size: ", minimum=0, maximum=50, step=1, value=20
1749
+ )
1750
+ with gr.Row():
1751
+ dilation_mask_button = gr.Button("Dilation Generated Mask")
1752
+ erosion_mask_button = gr.Button("Erosion Generated Mask")
1753
+
1754
+ moving_pixels = gr.Slider(
1755
+ label="Moving pixels:", show_label=True, minimum=0, maximum=50, value=4, step=1
1756
+ )
1757
+ with gr.Row():
1758
+ move_left_button = gr.Button("Move Left")
1759
+ move_right_button = gr.Button("Move Right")
1760
+ with gr.Row():
1761
+ move_up_button = gr.Button("Move Up")
1762
+ move_down_button = gr.Button("Move Down")
1763
+
1764
+ with gr.Tab(elem_classes="feedback", label="Output"):
1765
+ result_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", preview=True, height=400)
1766
+
1767
+ target_prompt_output = gr.Text(label="Output Target Prompt", value="", lines=1, interactive=False)
1768
+
1769
+ reset_button = gr.Button("Reset")
1770
+
1771
+ init_type = gr.Textbox(label="Init Name", value="", visible=False)
1772
+ example_type = gr.Textbox(label="Example Name", value="", visible=False)
1773
+
1774
+
1775
+
1776
+ with gr.Row():
1777
+ example = gr.Examples(
1778
+ label="Quick Example",
1779
+ examples=EXAMPLES,
1780
+ inputs=[input_image, prompt, seed, init_type, example_type, blending, resize_default, vlm_model_dropdown],
1781
+ examples_per_page=10,
1782
+ cache_examples=False,
1783
+ )
1784
+
1785
+
1786
+ with gr.Accordion(label="🎬 Feature Details:", open=True, elem_id="accordion"):
1787
+ with gr.Row(equal_height=True):
1788
+ gr.Markdown(tips)
1789
+
1790
+ with gr.Row():
1791
+ gr.Markdown(citation)
1792
+
1793
+ ## gr.examples can not be used to update the gr.Gallery, so we need to use the following two functions to update the gr.Gallery.
1794
+ ## And we need to solve the conflict between the upload and change example functions.
1795
+ input_image.upload(
1796
+ init_img,
1797
+ [input_image, init_type, prompt, aspect_ratio, example_change_times],
1798
+ [input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, init_type, aspect_ratio, resize_default, invert_mask_state, example_change_times]
1799
+
1800
+ )
1801
+ example_type.change(fn=update_example, inputs=[example_type, prompt, example_change_times], outputs=[input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, target_prompt, invert_mask_state, example_change_times])
1802
+
1803
+
1804
+ ## vlm and base model dropdown
1805
+ vlm_model_dropdown.change(fn=update_vlm_model, inputs=[vlm_model_dropdown], outputs=[status])
1806
+ base_model_dropdown.change(fn=update_base_model, inputs=[base_model_dropdown], outputs=[status])
1807
+
1808
+
1809
+ GPT4o_KEY_submit.click(fn=submit_GPT4o_KEY, inputs=[GPT4o_KEY], outputs=[GPT4o_KEY, vlm_model_dropdown])
1810
+ invert_mask_button.click(fn=invert_mask, inputs=[input_image, original_image, original_mask], outputs=[masked_gallery, mask_gallery, original_mask, invert_mask_state])
1811
+
1812
+
1813
+ ips=[input_image,
1814
+ original_image,
1815
+ original_mask,
1816
+ prompt,
1817
+ negative_prompt,
1818
+ control_strength,
1819
+ seed,
1820
+ randomize_seed,
1821
+ guidance_scale,
1822
+ num_inference_steps,
1823
+ num_samples,
1824
+ blending,
1825
+ category,
1826
+ target_prompt,
1827
+ resize_default,
1828
+ aspect_ratio,
1829
+ invert_mask_state]
1830
+
1831
+ ## run brushedit
1832
+ run_button.click(fn=process, inputs=ips, outputs=[result_gallery, mask_gallery, masked_gallery, prompt, target_prompt, invert_mask_state])
1833
+
1834
+
1835
+ ## mask func
1836
+ mask_button.click(fn=process_mask, inputs=[input_image, original_image, prompt, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask, category])
1837
+ random_mask_button.click(fn=process_random_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
1838
+ dilation_mask_button.click(fn=process_dilation_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[ masked_gallery, mask_gallery, original_mask])
1839
+ erosion_mask_button.click(fn=process_erosion_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[ masked_gallery, mask_gallery, original_mask])
1840
+
1841
+ ## reset func
1842
+ reset_button.click(fn=reset_func, inputs=[input_image, original_image, original_mask, prompt, target_prompt], outputs=[input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, resize_default, invert_mask_state])
1843
+
1844
+
1845
+
1846
+ # Bind the event handlers
1847
+ input_image.upload(fn=generate_blip_description, inputs=[input_image], outputs=[blip_description, blip_output])
1848
+ verify_deepseek.click(fn=verify_deepseek_api, outputs=[deepseek_verified, deepseek_key])
1849
+ enhance_button.click(fn=enhance_description, inputs=[blip_output, prompt], outputs=[enhanced_description, enhanced_output])
1850
+ decompose_button.click(fn=decompose_description, inputs=[enhanced_output], outputs=[decomposed_description, decomposed_output])
1851
+ # Updated event binding
1852
+ retrieve_button.click(
1853
+ fn=mix_and_search,
1854
+ inputs=[enhanced_output, result_gallery],
1855
+ outputs=[retrieve_output, retrieve_gallery]
1856
+ )
1857
+
1858
+ demo.launch(server_name="0.0.0.0", server_port=12345, share=True)
1859
+
1860
+
brushedit_app_new_jietu.py ADDED
The diff for this file is too large to render. See raw diff
 
brushedit_app_new_jietu2.py ADDED
The diff for this file is too large to render. See raw diff
 
brushedit_app_new_notqwen.py ADDED
The diff for this file is too large to render. See raw diff
 
brushedit_app_old.py ADDED
@@ -0,0 +1,1702 @@
1
+ ##!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import os, random, sys
4
+ import numpy as np
5
+ import requests
6
+ import torch
7
+
8
+
9
+ import gradio as gr
10
+
11
+ from PIL import Image
12
+
13
+
14
+ from huggingface_hub import hf_hub_download, snapshot_download
15
+ from scipy.ndimage import binary_dilation, binary_erosion
16
+ from transformers import (LlavaNextProcessor, LlavaNextForConditionalGeneration,
17
+ Qwen2VLForConditionalGeneration, Qwen2VLProcessor)
18
+
19
+ from segment_anything import SamPredictor, build_sam, SamAutomaticMaskGenerator
20
+ from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler
21
+ from diffusers.image_processor import VaeImageProcessor
22
+
23
+
24
+ from app.src.vlm_pipeline import (
25
+ vlm_response_editing_type,
26
+ vlm_response_object_wait_for_edit,
27
+ vlm_response_mask,
28
+ vlm_response_prompt_after_apply_instruction
29
+ )
30
+ from app.src.brushedit_all_in_one_pipeline import BrushEdit_Pipeline
31
+ from app.utils.utils import load_grounding_dino_model
32
+
33
+ from app.src.vlm_template import vlms_template
34
+ from app.src.base_model_template import base_models_template
35
+ from app.src.aspect_ratio_template import aspect_ratios
36
+
37
+ from openai import OpenAI
38
+ # base_openai_url = ""
39
+
40
+ #### Description ####
41
+ logo = r"""
42
+ <center><img src='./assets/logo_brushedit.png' alt='BrushEdit logo' style="width:80px; margin-bottom:10px"></center>
43
+ """
44
+ head = r"""
45
+ <div style="text-align: center;">
46
+ <h1> BrushEdit: All-In-One Image Inpainting and Editing</h1>
47
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
48
+ <a href='https://liyaowei-stu.github.io/project/BrushEdit/'><img src='https://img.shields.io/badge/Project_Page-BrushEdit-green' alt='Project Page'></a>
49
+ <a href='https://arxiv.org/abs/2412.10316'><img src='https://img.shields.io/badge/Paper-Arxiv-blue'></a>
50
+ <a href='https://github.com/TencentARC/BrushEdit'><img src='https://img.shields.io/badge/Code-Github-orange'></a>
51
+
52
+ </div>
53
+ </br>
54
+ </div>
55
+ """
56
+ descriptions = r"""
57
+ Official Gradio Demo for <a href='https://tencentarc.github.io/BrushNet/'><b>BrushEdit: All-In-One Image Inpainting and Editing</b></a><br>
58
+ 🧙 BrushEdit enables precise, user-friendly instruction-based image editing via an inpainting model.<br>
59
+ """
60
+
61
+ instructions = r"""
62
+ Currently, we support two modes: <b>fully automated instruction-based editing</b> and <b>interactive instruction-based editing</b>.
63
+
64
+ 🛠️ <b>Fully automated instruction-based editing</b>:
65
+ <ul>
66
+ <li> ⭐️ <b>1.Choose Image: </b> Upload <img src="https://github.com/user-attachments/assets/f2dca1e6-31f9-4716-ae84-907f24415bac" alt="upload" style="display:inline; height:1em; vertical-align:middle;"> or select <img src="https://github.com/user-attachments/assets/de808f7d-c74a-44c7-9cbf-f0dbfc2c1abf" alt="example" style="display:inline; height:1em; vertical-align:middle;"> one image from Example. </li>
67
+ <li> ⭐️ <b>2.Input ⌨️ Instructions: </b> Input the instructions (addition, deletion, and modification are supported), e.g., remove the xxx.</li>
68
+ <li> ⭐️ <b>3.Run: </b> Click the <b>💫 Run</b> button to automatically edit the image.</li>
69
+ </ul>
70
+
71
+ 🛠️ <b>Interactive instruction-based editing</b>:
72
+ <ul>
73
+ <li> ⭐️ <b>1.Choose Image: </b> Upload <img src="https://github.com/user-attachments/assets/f2dca1e6-31f9-4716-ae84-907f24415bac" alt="upload" style="display:inline; height:1em; vertical-align:middle;"> or select <img src="https://github.com/user-attachments/assets/de808f7d-c74a-44c7-9cbf-f0dbfc2c1abf" alt="example" style="display:inline; height:1em; vertical-align:middle;"> one image from Example. </li>
74
+ <li> ⭐️ <b>2.Finely Brushing: </b> Use a brush <img src="https://github.com/user-attachments/assets/c466c5cc-ac8f-4b4a-9bc5-04c4737fe1ef" alt="brush" style="display:inline; height:1em; vertical-align:middle;"> to outline the area you want to edit. You can also use the eraser <img src="https://github.com/user-attachments/assets/b6370369-b080-4550-b0d0-830ff22d9068" alt="eraser" style="display:inline; height:1em; vertical-align:middle;"> to restore. </li>
75
+ <li> ⭐️ <b>3.Input ⌨️ Instructions: </b> Input the instructions. </li>
76
+ <li> ⭐️ <b>4.Run: </b> Click the <b>💫 Run</b> button to automatically edit the image. </li>
77
+ </ul>
78
+
79
+ <b> We strongly recommend using GPT-4o for reasoning. </b> After selecting GPT-4o as the VLM model, enter the API key and click the Submit and Verify button. If the output reports success, you can use GPT-4o normally. As a second choice, we recommend the Qwen2VL model.
80
+
81
+ <b> We recommend zooming out in your browser for a better viewing range and experience. </b>
82
+
83
+ <b> For more detailed feature descriptions, see the bottom. </b>
84
+
85
+ ☕️ Have fun! 🎄 Wishing you a merry Christmas!
86
+ """
87
+
88
+ tips = r"""
89
+ 💡 <b>Some Tips</b>:
90
+ <ul>
91
+ <li> 🤠 After inputting the instructions, you can click the <b>Generate Mask</b> button. The mask generated by the VLM will be displayed in the preview panel on the right side. </li>
92
+ <li> 🤠 After generating the mask or when you use the brush to draw the mask, you can perform operations such as <b>randomization</b>, <b>dilation</b>, <b>erosion</b>, and <b>movement</b>. </li>
93
+ <li> 🤠 After inputting the instructions, you can click the <b>Generate Target Prompt</b> button. The target prompt will be displayed in the text box, and you can modify it according to your ideas. </li>
94
+ </ul>
95
+
96
+ 💡 <b>Detailed Features</b>:
97
+ <ul>
98
+ <li> 🎨 <b>Aspect Ratio</b>: Select the aspect ratio of the image. To prevent OOM, 1024px is the maximum resolution.</li>
99
+ <li> 🎨 <b>VLM Model</b>: Select the VLM model. We use preloaded models to save time. To use other VLM models, download them and uncomment the relevant lines in vlm_template.py from our GitHub repo. </li>
100
+ <li> 🎨 <b>Generate Mask</b>: According to the input instructions, generate a mask for the area that may need to be edited. </li>
101
+ <li> 🎨 <b>Square/Circle Mask</b>: Based on the existing mask, generate masks for squares and circles. (The coarse-grained mask provides more editing imagination.) </li>
102
+ <li> 🎨 <b>Invert Mask</b>: Invert the mask to generate a new mask. </li>
103
+ <li> 🎨 <b>Dilation/Erosion Mask</b>: Expand or shrink the mask to include or exclude more areas. </li>
104
+ <li> 🎨 <b>Move Mask</b>: Move the mask to a new position. </li>
105
+ <li> 🎨 <b>Generate Target Prompt</b>: Generate a target prompt based on the input instructions. </li>
106
+ <li> 🎨 <b>Target Prompt</b>: Description for masking area, manual input or modification can be made when the content generated by VLM does not meet expectations. </li>
107
+ <li> 🎨 <b>Blending</b>: Blend BrushNet's output with the original input, preserving the original image details in the unedited areas. (Turning it off is better when removing.) </li>
108
+ <li> 🎨 <b>Control strength</b>: The intensity of editing and inpainting. </li>
109
+ </ul>
110
+
111
+ 💡 <b>Advanced Features</b>:
112
+ <ul>
113
+ <li> 🎨 <b>Base Model</b>: We use preloaded models to save time. To use other base models, download them and uncomment the relevant lines in base_model_template.py from our GitHub repo. </li>
114
+ <li> 🎨 <b>Blending</b>: Blend BrushNet's output with the original input, preserving the original image details in the unedited areas. (Turning it off is better when removing.) </li>
115
+ <li> 🎨 <b>Control strength</b>: The intensity of editing and inpainting. </li>
116
+ <li> 🎨 <b>Num samples</b>: The number of samples to generate. </li>
117
+ <li> 🎨 <b>Negative prompt</b>: The negative prompt for the classifier-free guidance. </li>
118
+ <li> 🎨 <b>Guidance scale</b>: The guidance scale for the classifier-free guidance. </li>
119
+ </ul>
120
+
121
+
122
+ """
123
+
124
+
125
+
126
+ citation = r"""
127
+ If BrushEdit is helpful, please help to ⭐ the <a href='https://github.com/TencentARC/BrushEdit' target='_blank'>Github Repo</a>. Thanks!
128
+ [![GitHub Stars](https://img.shields.io/github/stars/TencentARC/BrushEdit?style=social)](https://github.com/TencentARC/BrushEdit)
129
+ ---
130
+ 📝 **Citation**
131
+ <br>
132
+ If our work is useful for your research, please consider citing:
133
+ ```bibtex
134
+ @misc{li2024brushedit,
135
+ title={BrushEdit: All-In-One Image Inpainting and Editing},
136
+ author={Yaowei Li and Yuxuan Bian and Xuan Ju and Zhaoyang Zhang and Junhao Zhuang and Ying Shan and Yuexian Zou and Qiang Xu},
137
+ year={2024},
138
+ eprint={2412.10316},
139
+ archivePrefix={arXiv},
140
+ primaryClass={cs.CV}
141
+ }
142
+ ```
143
+ 📧 **Contact**
144
+ <br>
145
+ If you have any questions, please feel free to reach out to me at <b>liyaowei@gmail.com</b>.
146
+ """
147
+
148
+ # - - - - - examples - - - - - #
149
+ EXAMPLES = [
150
+
151
+ [
152
+ Image.open("./assets/frog/frog.jpeg").convert("RGBA"),
153
+ "add a magic hat on frog head.",
154
+ 642087011,
155
+ "frog",
156
+ "frog",
157
+ True,
158
+ False,
159
+ "GPT4-o (Highly Recommended)"
160
+ ],
161
+ [
162
+ Image.open("./assets/chinese_girl/chinese_girl.png").convert("RGBA"),
163
+ "replace the background to ancient China.",
164
+ 648464818,
165
+ "chinese_girl",
166
+ "chinese_girl",
167
+ True,
168
+ False,
169
+ "GPT4-o (Highly Recommended)"
170
+ ],
171
+ [
172
+ Image.open("./assets/angel_christmas/angel_christmas.png").convert("RGBA"),
173
+ "remove the deer.",
174
+ 648464818,
175
+ "angel_christmas",
176
+ "angel_christmas",
177
+ False,
178
+ False,
179
+ "GPT4-o (Highly Recommended)"
180
+ ],
181
+ [
182
+ Image.open("./assets/sunflower_girl/sunflower_girl.png").convert("RGBA"),
183
+ "add a wreath on head.",
184
+ 648464818,
185
+ "sunflower_girl",
186
+ "sunflower_girl",
187
+ True,
188
+ False,
189
+ "GPT4-o (Highly Recommended)"
190
+ ],
191
+ [
192
+ Image.open("./assets/girl_on_sun/girl_on_sun.png").convert("RGBA"),
193
+ "add a butterfly fairy.",
194
+ 648464818,
195
+ "girl_on_sun",
196
+ "girl_on_sun",
197
+ True,
198
+ False,
199
+ "GPT4-o (Highly Recommended)"
200
+ ],
201
+ [
202
+ Image.open("./assets/spider_man_rm/spider_man.png").convert("RGBA"),
203
+ "remove the christmas hat.",
204
+ 642087011,
205
+ "spider_man_rm",
206
+ "spider_man_rm",
207
+ False,
208
+ False,
209
+ "GPT4-o (Highly Recommended)"
210
+ ],
211
+ [
212
+ Image.open("./assets/anime_flower/anime_flower.png").convert("RGBA"),
213
+ "remove the flower.",
214
+ 642087011,
215
+ "anime_flower",
216
+ "anime_flower",
217
+ False,
218
+ False,
219
+ "GPT4-o (Highly Recommended)"
220
+ ],
221
+ [
222
+ Image.open("./assets/chenduling/chengduling.jpg").convert("RGBA"),
223
+ "replace the clothes to a delicated floral skirt.",
224
+ 648464818,
225
+ "chenduling",
226
+ "chenduling",
227
+ True,
228
+ False,
229
+ "GPT4-o (Highly Recommended)"
230
+ ],
231
+ [
232
+ Image.open("./assets/hedgehog_rp_bg/hedgehog.png").convert("RGBA"),
233
+ "make the hedgehog in Italy.",
234
+ 648464818,
235
+ "hedgehog_rp_bg",
236
+ "hedgehog_rp_bg",
237
+ True,
238
+ False,
239
+ "GPT4-o (Highly Recommended)"
240
+ ],
241
+
242
+ ]
243
+
244
+ INPUT_IMAGE_PATH = {
245
+ "frog": "./assets/frog/frog.jpeg",
246
+ "chinese_girl": "./assets/chinese_girl/chinese_girl.png",
247
+ "angel_christmas": "./assets/angel_christmas/angel_christmas.png",
248
+ "sunflower_girl": "./assets/sunflower_girl/sunflower_girl.png",
249
+ "girl_on_sun": "./assets/girl_on_sun/girl_on_sun.png",
250
+ "spider_man_rm": "./assets/spider_man_rm/spider_man.png",
251
+ "anime_flower": "./assets/anime_flower/anime_flower.png",
252
+ "chenduling": "./assets/chenduling/chengduling.jpg",
253
+ "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/hedgehog.png",
254
+ }
255
+ MASK_IMAGE_PATH = {
256
+ "frog": "./assets/frog/mask_f7b350de-6f2c-49e3-b535-995c486d78e7.png",
257
+ "chinese_girl": "./assets/chinese_girl/mask_54759648-0989-48e0-bc82-f20e28b5ec29.png",
258
+ "angel_christmas": "./assets/angel_christmas/mask_f15d9b45-c978-4e3d-9f5f-251e308560c3.png",
259
+ "sunflower_girl": "./assets/sunflower_girl/mask_99cc50b4-7dc4-4de5-8748-ec10772f0317.png",
260
+ "girl_on_sun": "./assets/girl_on_sun/mask_264eac8b-8b65-479c-9755-020a60880c37.png",
261
+ "spider_man_rm": "./assets/spider_man_rm/mask_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png",
262
+ "anime_flower": "./assets/anime_flower/mask_37553172-9b38-4727-bf2e-37d7e2b93461.png",
263
+ "chenduling": "./assets/chenduling/mask_68e3ff6f-da07-4b37-91df-13d6eed7b997.png",
264
+ "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/mask_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png",
265
+ }
266
+ MASKED_IMAGE_PATH = {
267
+ "frog": "./assets/frog/masked_image_f7b350de-6f2c-49e3-b535-995c486d78e7.png",
268
+ "chinese_girl": "./assets/chinese_girl/masked_image_54759648-0989-48e0-bc82-f20e28b5ec29.png",
269
+ "angel_christmas": "./assets/angel_christmas/masked_image_f15d9b45-c978-4e3d-9f5f-251e308560c3.png",
270
+ "sunflower_girl": "./assets/sunflower_girl/masked_image_99cc50b4-7dc4-4de5-8748-ec10772f0317.png",
271
+ "girl_on_sun": "./assets/girl_on_sun/masked_image_264eac8b-8b65-479c-9755-020a60880c37.png",
272
+ "spider_man_rm": "./assets/spider_man_rm/masked_image_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png",
273
+ "anime_flower": "./assets/anime_flower/masked_image_37553172-9b38-4727-bf2e-37d7e2b93461.png",
274
+ "chenduling": "./assets/chenduling/masked_image_68e3ff6f-da07-4b37-91df-13d6eed7b997.png",
275
+ "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/masked_image_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png",
276
+ }
277
+ OUTPUT_IMAGE_PATH = {
278
+ "frog": "./assets/frog/image_edit_f7b350de-6f2c-49e3-b535-995c486d78e7_1.png",
279
+ "chinese_girl": "./assets/chinese_girl/image_edit_54759648-0989-48e0-bc82-f20e28b5ec29_1.png",
280
+ "angel_christmas": "./assets/angel_christmas/image_edit_f15d9b45-c978-4e3d-9f5f-251e308560c3_0.png",
281
+ "sunflower_girl": "./assets/sunflower_girl/image_edit_99cc50b4-7dc4-4de5-8748-ec10772f0317_3.png",
282
+ "girl_on_sun": "./assets/girl_on_sun/image_edit_264eac8b-8b65-479c-9755-020a60880c37_0.png",
283
+ "spider_man_rm": "./assets/spider_man_rm/image_edit_a5d410e6-8e8d-432f-8144-defbc3e1eae9_0.png",
284
+ "anime_flower": "./assets/anime_flower/image_edit_37553172-9b38-4727-bf2e-37d7e2b93461_2.png",
285
+ "chenduling": "./assets/chenduling/image_edit_68e3ff6f-da07-4b37-91df-13d6eed7b997_0.png",
286
+ "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/image_edit_db7f8bf8-8349-46d3-b14e-43d67fbe25d3_3.png",
287
+ }
288
+
289
+ # os.environ['GRADIO_TEMP_DIR'] = 'gradio_temp_dir'
290
+ # os.makedirs('gradio_temp_dir', exist_ok=True)
291
+
292
+ VLM_MODEL_NAMES = list(vlms_template.keys())
293
+ DEFAULT_VLM_MODEL_NAME = "Qwen2-VL-7B-Instruct (Default)"
294
+ BASE_MODELS = list(base_models_template.keys())
295
+ DEFAULT_BASE_MODEL = "realisticVision (Default)"
296
+
297
+ ASPECT_RATIO_LABELS = list(aspect_ratios)
298
+ DEFAULT_ASPECT_RATIO = ASPECT_RATIO_LABELS[0]
299
+
300
+
301
+ ## init device
302
+ try:
303
+ if torch.cuda.is_available():
304
+ device = "cuda"
305
+ elif sys.platform == "darwin" and torch.backends.mps.is_available():
306
+ device = "mps"
307
+ else:
308
+ device = "cpu"
309
+ except Exception:
310
+ device = "cpu"
311
+
312
+ # ## init torch dtype
313
+ # if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
314
+ # torch_dtype = torch.bfloat16
315
+ # else:
316
+ # torch_dtype = torch.float16
317
+
318
+ # if device == "mps":
319
+ # torch_dtype = torch.float16
320
+
321
+ torch_dtype = torch.float16
322
+
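+ # Illustrative sanity check (a minimal sketch, hypothetical and for clarity only): confirm the
+ # selected device/dtype pair can allocate a tensor before the heavy pipelines are loaded.
+ # >>> torch.zeros(1, device=device, dtype=torch_dtype).dtype
+ # torch.float16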
323
+
324
+
325
+ # download hf models
326
+ BrushEdit_path = "models/"
327
+ if not os.path.exists(BrushEdit_path):
328
+ BrushEdit_path = snapshot_download(
329
+ repo_id="TencentARC/BrushEdit",
330
+ local_dir=BrushEdit_path,
331
+ token=os.getenv("HF_TOKEN"),
332
+ )
333
+
334
+ ## init default VLM
335
+ vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[DEFAULT_VLM_MODEL_NAME]
336
+ if vlm_processor != "" and vlm_model != "":
337
+ vlm_model.to(device)
338
+ else:
339
+ raise gr.Error("Please Download default VLM model "+ DEFAULT_VLM_MODEL_NAME +" first.")
340
+
341
+
342
+ ## init base model
343
+ base_model_path = os.path.join(BrushEdit_path, "base_model/realisticVisionV60B1_v51VAE")
344
+ brushnet_path = os.path.join(BrushEdit_path, "brushnetX")
345
+ sam_path = os.path.join(BrushEdit_path, "sam/sam_vit_h_4b8939.pth")
346
+ groundingdino_path = os.path.join(BrushEdit_path, "grounding_dino/groundingdino_swint_ogc.pth")
347
+
348
+
349
+ # input brushnetX ckpt path
350
+ brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch_dtype)
351
+ pipe = StableDiffusionBrushNetPipeline.from_pretrained(
352
+ base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
353
+ )
354
+ # speed up diffusion process with faster scheduler and memory optimization
355
+ pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
356
+ # uncomment the following line to enable xformers memory-efficient attention (requires xformers; not needed with Torch 2.0)
357
+ # pipe.enable_xformers_memory_efficient_attention()
358
+ pipe.enable_model_cpu_offload()
359
+
360
+
361
+ ## init SAM
362
+ sam = build_sam(checkpoint=sam_path)
363
+ sam.to(device=device)
364
+ sam_predictor = SamPredictor(sam)
365
+ sam_automask_generator = SamAutomaticMaskGenerator(sam)
366
+
367
+ ## init groundingdino_model
368
+ config_file = 'app/utils/GroundingDINO_SwinT_OGC.py'
369
+ groundingdino_model = load_grounding_dino_model(config_file, groundingdino_path, device=device)
370
+
371
+ ## Ordinary function
372
+ def crop_and_resize(image: Image.Image,
373
+ target_width: int,
374
+ target_height: int) -> Image.Image:
375
+ """
376
+ Crops and resizes an image while preserving the aspect ratio.
377
+
378
+ Args:
379
+ image (Image.Image): Input PIL image to be cropped and resized.
380
+ target_width (int): Target width of the output image.
381
+ target_height (int): Target height of the output image.
382
+
383
+ Returns:
384
+ Image.Image: Cropped and resized image.
385
+ """
386
+ # Original dimensions
387
+ original_width, original_height = image.size
388
+ original_aspect = original_width / original_height
389
+ target_aspect = target_width / target_height
390
+
391
+ # Calculate crop box to maintain aspect ratio
392
+ if original_aspect > target_aspect:
393
+ # Crop horizontally
394
+ new_width = int(original_height * target_aspect)
395
+ new_height = original_height
396
+ left = (original_width - new_width) / 2
397
+ top = 0
398
+ right = left + new_width
399
+ bottom = original_height
400
+ else:
401
+ # Crop vertically
402
+ new_width = original_width
403
+ new_height = int(original_width / target_aspect)
404
+ left = 0
405
+ top = (original_height - new_height) / 2
406
+ right = original_width
407
+ bottom = top + new_height
408
+
409
+ # Crop and resize
410
+ cropped_image = image.crop((left, top, right, bottom))
411
+ resized_image = cropped_image.resize((target_width, target_height), Image.NEAREST)
412
+ return resized_image
413
+
414
+
415
+ ## Ordinary function
416
+ def resize(image: Image.Image,
417
+ target_width: int,
418
+ target_height: int) -> Image.Image:
419
+ """
420
+ Resizes an image directly to the target size without preserving the aspect ratio.
421
+
422
+ Args:
423
+ image (Image.Image): Input PIL image to be resized.
424
+ target_width (int): Target width of the output image.
425
+ target_height (int): Target height of the output image.
426
+
427
+ Returns:
428
+ Image.Image: Resized image.
429
+ """
430
+ # Resize directly to the requested dimensions
431
+ resized_image = image.resize((target_width, target_height), Image.NEAREST)
432
+ return resized_image
433
+
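+ # Illustrative usage sketch (hypothetical example for clarity; assumes an 800x600 input photo):
+ # >>> img = Image.open("some_photo.jpg")       # hypothetical 800x600 RGB image
+ # >>> crop_and_resize(img, 640, 640).size      # center-crops to 600x600 first, then resizes
+ # (640, 640)
+ # >>> resize(img, 640, 640).size               # stretches directly; aspect ratio is not kept
+ # (640, 640)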
434
+
435
+ def move_mask_func(mask, direction, units):
436
+ binary_mask = mask.squeeze()>0
437
+ rows, cols = binary_mask.shape
438
+ moved_mask = np.zeros_like(binary_mask, dtype=bool)
439
+
440
+ if direction == 'down':
441
+ # move down
442
+ moved_mask[max(0, units):, :] = binary_mask[:rows - units, :]
443
+
444
+ elif direction == 'up':
445
+ # move up
446
+ moved_mask[:rows - units, :] = binary_mask[units:, :]
447
+
448
+ elif direction == 'right':
449
+ # move right
450
+ moved_mask[:, max(0, units):] = binary_mask[:, :cols - units]
451
+
452
+ elif direction == 'left':
453
+ # move left
454
+ moved_mask[:, :cols - units] = binary_mask[:, units:]
455
+
456
+ return moved_mask
457
+
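+ # Illustrative usage sketch (hypothetical example for clarity): shift a small binary mask downwards.
+ # >>> toy = np.zeros((8, 8, 1), dtype=np.uint8); toy[2:5, 2:5] = 255
+ # >>> moved = move_mask_func(toy, 'down', 2)   # boolean (8, 8) array; the block now covers rows 4-6
+ # >>> bool(moved[4:7, 2:5].all())
+ # True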
458
+
459
+ def random_mask_func(mask, dilation_type='square_dilation', dilation_size=20):
460
+ # Binarize the mask, then apply the requested dilation/erosion or bounding-shape operation
461
+ binary_mask = mask.squeeze()>0
462
+
463
+ if dilation_type == 'square_dilation':
464
+ structure = np.ones((dilation_size, dilation_size), dtype=bool)
465
+ dilated_mask = binary_dilation(binary_mask, structure=structure)
466
+ elif dilation_type == 'square_erosion':
467
+ structure = np.ones((dilation_size, dilation_size), dtype=bool)
468
+ dilated_mask = binary_erosion(binary_mask, structure=structure)
469
+ elif dilation_type == 'bounding_box':
470
+ # find the most left top and left bottom point
471
+ rows, cols = np.where(binary_mask)
472
+ if len(rows) == 0 or len(cols) == 0:
473
+ return mask # return original mask if no valid points
474
+
475
+ min_row = np.min(rows)
476
+ max_row = np.max(rows)
477
+ min_col = np.min(cols)
478
+ max_col = np.max(cols)
479
+
480
+ # create a bounding box
481
+ dilated_mask = np.zeros_like(binary_mask, dtype=bool)
482
+ dilated_mask[min_row:max_row + 1, min_col:max_col + 1] = True
483
+
484
+ elif dilation_type == 'bounding_ellipse':
485
+ # find the most left top and left bottom point
486
+ rows, cols = np.where(binary_mask)
487
+ if len(rows) == 0 or len(cols) == 0:
488
+ return mask # return original mask if no valid points
489
+
490
+ min_row = np.min(rows)
491
+ max_row = np.max(rows)
492
+ min_col = np.min(cols)
493
+ max_col = np.max(cols)
494
+
495
+ # calculate the center and axis length of the ellipse
496
+ center = ((min_col + max_col) // 2, (min_row + max_row) // 2)
497
+ a = (max_col - min_col) // 2 # half long axis
498
+ b = (max_row - min_row) // 2 # half short axis
499
+
500
+ # create a bounding ellipse
501
+ y, x = np.ogrid[:mask.shape[0], :mask.shape[1]]
502
+ ellipse_mask = ((x - center[0])**2 / a**2 + (y - center[1])**2 / b**2) <= 1
503
+ dilated_mask = np.zeros_like(binary_mask, dtype=bool)
504
+ dilated_mask[ellipse_mask] = True
505
+ else:
506
+ ValueError("dilation_type must be 'square' or 'ellipse'")
507
+
508
+ # use binary dilation
509
+ dilated_mask = np.uint8(dilated_mask[:,:,np.newaxis]) * 255
510
+ return dilated_mask
511
+
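+ # Illustrative usage sketch (hypothetical example for clarity): grow a toy mask with a square structuring element.
+ # >>> toy = np.zeros((16, 16, 1), dtype=np.uint8); toy[6:10, 6:10] = 255
+ # >>> grown = random_mask_func(toy, 'square_dilation', dilation_size=3)
+ # >>> grown.shape, int(grown.max())
+ # ((16, 16, 1), 255)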
512
+
513
+ ## Gradio component function
514
+ def update_vlm_model(vlm_name):
515
+ global vlm_model, vlm_processor
516
+ if vlm_model is not None:
517
+ del vlm_model
518
+ torch.cuda.empty_cache()
519
+
520
+ vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[vlm_name]
521
+
522
+ ## We recommend using preloaded models; otherwise downloading the model can take a long time. You can edit the preload list in vlm_template.py.
523
+ if vlm_type == "llava-next":
524
+ if vlm_processor != "" and vlm_model != "":
525
+ vlm_model.to(device)
526
+ return vlm_model_dropdown
527
+ else:
528
+ if os.path.exists(vlm_local_path):
529
+ vlm_processor = LlavaNextProcessor.from_pretrained(vlm_local_path)
530
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto")
531
+ else:
532
+ if vlm_name == "llava-v1.6-mistral-7b-hf (Preload)":
533
+ vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
534
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", torch_dtype="auto", device_map="auto")
535
+ elif vlm_name == "llama3-llava-next-8b-hf (Preload)":
536
+ vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llama3-llava-next-8b-hf")
537
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llama3-llava-next-8b-hf", torch_dtype="auto", device_map="auto")
538
+ elif vlm_name == "llava-v1.6-vicuna-13b-hf (Preload)":
539
+ vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf")
540
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf", torch_dtype="auto", device_map="auto")
541
+ elif vlm_name == "llava-v1.6-34b-hf (Preload)":
542
+ vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-34b-hf")
543
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-34b-hf", torch_dtype="auto", device_map="auto")
544
+ elif vlm_name == "llava-next-72b-hf (Preload)":
545
+ vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-next-72b-hf")
546
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-next-72b-hf", torch_dtype="auto", device_map="auto")
547
+ elif vlm_type == "qwen2-vl":
548
+ if vlm_processor != "" and vlm_model != "":
549
+ vlm_model.to(device)
550
+ return vlm_model_dropdown
551
+ else:
552
+ if os.path.exists(vlm_local_path):
553
+ vlm_processor = Qwen2VLProcessor.from_pretrained(vlm_local_path)
554
+ vlm_model = Qwen2VLForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto")
555
+ else:
556
+ if vlm_name == "qwen2-vl-2b-instruct (Preload)":
557
+ vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
558
+ vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", torch_dtype="auto", device_map="auto")
559
+ elif vlm_name == "qwen2-vl-7b-instruct (Preload)":
560
+ vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
561
+ vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto")
562
+ elif vlm_name == "qwen2-vl-72b-instruct (Preload)":
563
+ vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-72B-Instruct")
564
+ vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-72B-Instruct", torch_dtype="auto", device_map="auto")
565
+ elif vlm_type == "openai":
566
+ pass
567
+ return "success"
568
+
569
+
570
+ def update_base_model(base_model_name):
571
+ global pipe
572
+ ## We recommend using preloaded models; otherwise downloading the model can take a long time. You can edit the preload list in base_model_template.py.
573
+ if pipe is not None:
574
+ del pipe
575
+ torch.cuda.empty_cache()
576
+ base_model_path, pipe = base_models_template[base_model_name]
577
+ if pipe != "":
578
+ pipe.to(device)
579
+ else:
580
+ if os.path.exists(base_model_path):
581
+ pipe = StableDiffusionBrushNetPipeline.from_pretrained(
582
+ base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
583
+ )
584
+ # pipe.enable_xformers_memory_efficient_attention()
585
+ pipe.enable_model_cpu_offload()
586
+ else:
587
+ raise gr.Error(f"The base model {base_model_name} does not exist")
588
+ return "success"
589
+
590
+
591
+ def submit_GPT4o_KEY(GPT4o_KEY):
592
+ global vlm_model, vlm_processor
593
+ if vlm_model is not None:
594
+ del vlm_model
595
+ torch.cuda.empty_cache()
596
+ try:
597
+ vlm_model = OpenAI(api_key=GPT4o_KEY)
598
+ vlm_processor = ""
599
+ response = vlm_model.chat.completions.create(
600
+ model="gpt-4o-2024-08-06",
601
+ messages=[
602
+ {"role": "system", "content": "You are a helpful assistant."},
603
+ {"role": "user", "content": "Say this is a test"}
604
+ ]
605
+ )
606
+ response_str = response.choices[0].message.content
607
+
608
+ return "Success, " + response_str, "GPT4-o (Highly Recommended)"
609
+ except Exception as e:
610
+ return "Invalid GPT4o API Key", "GPT4-o (Highly Recommended)"
611
+
612
+
613
+
614
+ def process(input_image,
615
+ original_image,
616
+ original_mask,
617
+ prompt,
618
+ negative_prompt,
619
+ control_strength,
620
+ seed,
621
+ randomize_seed,
622
+ guidance_scale,
623
+ num_inference_steps,
624
+ num_samples,
625
+ blending,
626
+ category,
627
+ target_prompt,
628
+ resize_default,
629
+ aspect_ratio_name,
630
+ invert_mask_state):
631
+ if original_image is None:
632
+ if input_image is None:
633
+ raise gr.Error('Please upload the input image')
634
+ else:
635
+ image_pil = input_image["background"].convert("RGB")
636
+ original_image = np.array(image_pil)
637
+ if prompt is None or prompt == "":
638
+ if target_prompt is None or target_prompt == "":
639
+ raise gr.Error("Please input your instructions, e.g., remove the xxx")
640
+
641
+ alpha_mask = input_image["layers"][0].split()[3]
642
+ input_mask = np.asarray(alpha_mask)
643
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
644
+ if output_w == "" or output_h == "":
645
+ output_h, output_w = original_image.shape[:2]
646
+
647
+ if resize_default:
648
+ short_side = min(output_w, output_h)
649
+ scale_ratio = 640 / short_side
650
+ output_w = int(output_w * scale_ratio)
651
+ output_h = int(output_h * scale_ratio)
652
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
653
+ original_image = np.array(original_image)
654
+ if input_mask is not None:
655
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
656
+ input_mask = np.array(input_mask)
657
+ if original_mask is not None:
658
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
659
+ original_mask = np.array(original_mask)
660
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
661
+ else:
662
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
663
+ pass
664
+ else:
665
+ if resize_default:
666
+ short_side = min(output_w, output_h)
667
+ scale_ratio = 640 / short_side
668
+ output_w = int(output_w * scale_ratio)
669
+ output_h = int(output_h * scale_ratio)
670
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
671
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
672
+ original_image = np.array(original_image)
673
+ if input_mask is not None:
674
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
675
+ input_mask = np.array(input_mask)
676
+ if original_mask is not None:
677
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
678
+ original_mask = np.array(original_mask)
679
+
680
+ if invert_mask_state:
681
+ original_mask = original_mask
682
+ else:
683
+ if input_mask.max() == 0:
684
+ original_mask = original_mask
685
+ else:
686
+ original_mask = input_mask
687
+
688
+
689
+ ## inpainting directly if target_prompt is not None
690
+ if category is not None:
691
+ pass
692
+ elif target_prompt is not None and len(target_prompt) >= 1 and original_mask is not None:
693
+ pass
694
+ else:
695
+ try:
696
+ category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device)
697
+ except Exception as e:
698
+ raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
699
+
700
+
701
+ if original_mask is not None:
702
+ original_mask = np.clip(original_mask, 0, 255).astype(np.uint8)
703
+ else:
704
+ try:
705
+ object_wait_for_edit = vlm_response_object_wait_for_edit(
706
+ vlm_processor,
707
+ vlm_model,
708
+ original_image,
709
+ category,
710
+ prompt,
711
+ device)
712
+
713
+ original_mask = vlm_response_mask(vlm_processor,
714
+ vlm_model,
715
+ category,
716
+ original_image,
717
+ prompt,
718
+ object_wait_for_edit,
719
+ sam,
720
+ sam_predictor,
721
+ sam_automask_generator,
722
+ groundingdino_model,
723
+ device).astype(np.uint8)
724
+ except Exception as e:
725
+ raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
726
+
727
+ if original_mask.ndim == 2:
728
+ original_mask = original_mask[:,:,None]
729
+
730
+
731
+ if target_prompt is not None and len(target_prompt) >= 1:
732
+ prompt_after_apply_instruction = target_prompt
733
+
734
+ else:
735
+ try:
736
+ prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(
737
+ vlm_processor,
738
+ vlm_model,
739
+ original_image,
740
+ prompt,
741
+ device)
742
+ except Exception as e:
743
+ raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
744
+
745
+ generator = torch.Generator(device).manual_seed(random.randint(0, 2147483647) if randomize_seed else seed)
746
+
747
+
748
+ with torch.autocast(device):
749
+ image, mask_image, mask_np, init_image_np = BrushEdit_Pipeline(pipe,
750
+ prompt_after_apply_instruction,
751
+ original_mask,
752
+ original_image,
753
+ generator,
754
+ num_inference_steps,
755
+ guidance_scale,
756
+ control_strength,
757
+ negative_prompt,
758
+ num_samples,
759
+ blending)
760
+ original_image = np.array(init_image_np)
761
+ masked_image = original_image * (1 - (mask_np>0))
762
+ masked_image = masked_image.astype(np.uint8)
763
+ masked_image = Image.fromarray(masked_image)
764
+ # Save the images (optional)
765
+ # import uuid
766
+ # uuid = str(uuid.uuid4())
767
+ # image[0].save(f"outputs/image_edit_{uuid}_0.png")
768
+ # image[1].save(f"outputs/image_edit_{uuid}_1.png")
769
+ # image[2].save(f"outputs/image_edit_{uuid}_2.png")
770
+ # image[3].save(f"outputs/image_edit_{uuid}_3.png")
771
+ # mask_image.save(f"outputs/mask_{uuid}.png")
772
+ # masked_image.save(f"outputs/masked_image_{uuid}.png")
773
+ gr.Info(f"Target Prompt: {prompt_after_apply_instruction}", duration=20)
774
+ return image, [mask_image], [masked_image], prompt, '', False
775
+
776
+
777
+ def generate_target_prompt(input_image,
778
+ original_image,
779
+ prompt):
780
+ # load example image
781
+ if isinstance(original_image, str):
782
+ original_image = input_image
783
+
784
+ prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(
785
+ vlm_processor,
786
+ vlm_model,
787
+ original_image,
788
+ prompt,
789
+ device)
790
+ return prompt_after_apply_instruction
791
+
792
+
793
+ def process_mask(input_image,
794
+ original_image,
795
+ prompt,
796
+ resize_default,
797
+ aspect_ratio_name):
798
+ if original_image is None:
799
+ raise gr.Error('Please upload the input image')
800
+ if prompt is None:
801
+ raise gr.Error("Please input your instructions, e.g., remove the xxx")
802
+
803
+ ## load mask
804
+ alpha_mask = input_image["layers"][0].split()[3]
805
+ input_mask = np.array(alpha_mask)
806
+
807
+ # load example image
808
+ if isinstance(original_image, str):
809
+ original_image = input_image["background"]
810
+
811
+ if input_mask.max() == 0:
812
+ category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device)
813
+
814
+ object_wait_for_edit = vlm_response_object_wait_for_edit(vlm_processor,
815
+ vlm_model,
816
+ original_image,
817
+ category,
818
+ prompt,
819
+ device)
820
+ # original mask: h,w,1 [0, 255]
821
+ original_mask = vlm_response_mask(
822
+ vlm_processor,
823
+ vlm_model,
824
+ category,
825
+ original_image,
826
+ prompt,
827
+ object_wait_for_edit,
828
+ sam,
829
+ sam_predictor,
830
+ sam_automask_generator,
831
+ groundingdino_model,
832
+ device).astype(np.uint8)
833
+ else:
834
+ original_mask = input_mask.astype(np.uint8)
835
+ category = None
836
+
837
+ ## resize mask if needed
838
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
839
+ if output_w == "" or output_h == "":
840
+ output_h, output_w = original_image.shape[:2]
841
+ if resize_default:
842
+ short_side = min(output_w, output_h)
843
+ scale_ratio = 640 / short_side
844
+ output_w = int(output_w * scale_ratio)
845
+ output_h = int(output_h * scale_ratio)
846
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
847
+ original_image = np.array(original_image)
848
+ if input_mask is not None:
849
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
850
+ input_mask = np.array(input_mask)
851
+ if original_mask is not None:
852
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
853
+ original_mask = np.array(original_mask)
854
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
855
+ else:
856
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
857
+ pass
858
+ else:
859
+ if resize_default:
860
+ short_side = min(output_w, output_h)
861
+ scale_ratio = 640 / short_side
862
+ output_w = int(output_w * scale_ratio)
863
+ output_h = int(output_h * scale_ratio)
864
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
865
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
866
+ original_image = np.array(original_image)
867
+ if input_mask is not None:
868
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
869
+ input_mask = np.array(input_mask)
870
+ if original_mask is not None:
871
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
872
+ original_mask = np.array(original_mask)
873
+
874
+
875
+ if original_mask.ndim == 2:
876
+ original_mask = original_mask[:,:,None]
877
+
878
+ mask_image = Image.fromarray(original_mask.squeeze().astype(np.uint8)).convert("RGB")
879
+
880
+ masked_image = original_image * (1 - (original_mask>0))
881
+ masked_image = masked_image.astype(np.uint8)
882
+ masked_image = Image.fromarray(masked_image)
883
+
884
+ return [masked_image], [mask_image], original_mask.astype(np.uint8), category
885
+
886
+
887
+ def process_random_mask(input_image,
888
+ original_image,
889
+ original_mask,
890
+ resize_default,
891
+ aspect_ratio_name,
892
+ ):
893
+
894
+ alpha_mask = input_image["layers"][0].split()[3]
895
+ input_mask = np.asarray(alpha_mask)
896
+
897
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
898
+ if output_w == "" or output_h == "":
899
+ output_h, output_w = original_image.shape[:2]
900
+ if resize_default:
901
+ short_side = min(output_w, output_h)
902
+ scale_ratio = 640 / short_side
903
+ output_w = int(output_w * scale_ratio)
904
+ output_h = int(output_h * scale_ratio)
905
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
906
+ original_image = np.array(original_image)
907
+ if input_mask is not None:
908
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
909
+ input_mask = np.array(input_mask)
910
+ if original_mask is not None:
911
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
912
+ original_mask = np.array(original_mask)
913
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
914
+ else:
915
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
916
+ pass
917
+ else:
918
+ if resize_default:
919
+ short_side = min(output_w, output_h)
920
+ scale_ratio = 640 / short_side
921
+ output_w = int(output_w * scale_ratio)
922
+ output_h = int(output_h * scale_ratio)
923
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
924
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
925
+ original_image = np.array(original_image)
926
+ if input_mask is not None:
927
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
928
+ input_mask = np.array(input_mask)
929
+ if original_mask is not None:
930
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
931
+ original_mask = np.array(original_mask)
932
+
933
+
934
+ if input_mask.max() == 0:
935
+ original_mask = original_mask
936
+ else:
937
+ original_mask = input_mask
938
+
939
+ if original_mask is None:
940
+ raise gr.Error('Please generate mask first')
941
+
942
+ if original_mask.ndim == 2:
943
+ original_mask = original_mask[:,:,None]
944
+
945
+ dilation_type = np.random.choice(['bounding_box', 'bounding_ellipse'])
946
+ random_mask = random_mask_func(original_mask, dilation_type).squeeze()
947
+
948
+ mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
949
+
950
+ masked_image = original_image * (1 - (random_mask[:,:,None]>0))
951
+ masked_image = masked_image.astype(original_image.dtype)
952
+ masked_image = Image.fromarray(masked_image)
953
+
954
+
955
+ return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
956
+
957
+
958
+ def process_dilation_mask(input_image,
959
+ original_image,
960
+ original_mask,
961
+ resize_default,
962
+ aspect_ratio_name,
963
+ dilation_size=20):
964
+
965
+ alpha_mask = input_image["layers"][0].split()[3]
966
+ input_mask = np.asarray(alpha_mask)
967
+
968
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
969
+ if output_w == "" or output_h == "":
970
+ output_h, output_w = original_image.shape[:2]
971
+ if resize_default:
972
+ short_side = min(output_w, output_h)
973
+ scale_ratio = 640 / short_side
974
+ output_w = int(output_w * scale_ratio)
975
+ output_h = int(output_h * scale_ratio)
976
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
977
+ original_image = np.array(original_image)
978
+ if input_mask is not None:
979
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
980
+ input_mask = np.array(input_mask)
981
+ if original_mask is not None:
982
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
983
+ original_mask = np.array(original_mask)
984
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
985
+ else:
986
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
987
+ pass
988
+ else:
989
+ if resize_default:
990
+ short_side = min(output_w, output_h)
991
+ scale_ratio = 640 / short_side
992
+ output_w = int(output_w * scale_ratio)
993
+ output_h = int(output_h * scale_ratio)
994
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
995
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
996
+ original_image = np.array(original_image)
997
+ if input_mask is not None:
998
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
999
+ input_mask = np.array(input_mask)
1000
+ if original_mask is not None:
1001
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1002
+ original_mask = np.array(original_mask)
1003
+
1004
+ if input_mask.max() == 0:
1005
+ original_mask = original_mask
1006
+ else:
1007
+ original_mask = input_mask
1008
+
1009
+ if original_mask is None:
1010
+ raise gr.Error('Please generate mask first')
1011
+
1012
+ if original_mask.ndim == 2:
1013
+ original_mask = original_mask[:,:,None]
1014
+
1015
+ dilation_type = np.random.choice(['square_dilation'])
1016
+ random_mask = random_mask_func(original_mask, dilation_type, dilation_size).squeeze()
1017
+
1018
+ mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
1019
+
1020
+ masked_image = original_image * (1 - (random_mask[:,:,None]>0))
1021
+ masked_image = masked_image.astype(original_image.dtype)
1022
+ masked_image = Image.fromarray(masked_image)
1023
+
1024
+ return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
1025
+
1026
+
1027
+ def process_erosion_mask(input_image,
1028
+ original_image,
1029
+ original_mask,
1030
+ resize_default,
1031
+ aspect_ratio_name,
1032
+ dilation_size=20):
1033
+ alpha_mask = input_image["layers"][0].split()[3]
1034
+ input_mask = np.asarray(alpha_mask)
1035
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
1036
+ if output_w == "" or output_h == "":
1037
+ output_h, output_w = original_image.shape[:2]
1038
+ if resize_default:
1039
+ short_side = min(output_w, output_h)
1040
+ scale_ratio = 640 / short_side
1041
+ output_w = int(output_w * scale_ratio)
1042
+ output_h = int(output_h * scale_ratio)
1043
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1044
+ original_image = np.array(original_image)
1045
+ if input_mask is not None:
1046
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1047
+ input_mask = np.array(input_mask)
1048
+ if original_mask is not None:
1049
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1050
+ original_mask = np.array(original_mask)
1051
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1052
+ else:
1053
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1054
+ pass
1055
+ else:
1056
+ if resize_default:
1057
+ short_side = min(output_w, output_h)
1058
+ scale_ratio = 640 / short_side
1059
+ output_w = int(output_w * scale_ratio)
1060
+ output_h = int(output_h * scale_ratio)
1061
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1062
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1063
+ original_image = np.array(original_image)
1064
+ if input_mask is not None:
1065
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1066
+ input_mask = np.array(input_mask)
1067
+ if original_mask is not None:
1068
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1069
+ original_mask = np.array(original_mask)
1070
+
1071
+ if input_mask.max() == 0:
1072
+ original_mask = original_mask
1073
+ else:
1074
+ original_mask = input_mask
1075
+
1076
+ if original_mask is None:
1077
+ raise gr.Error('Please generate mask first')
1078
+
1079
+ if original_mask.ndim == 2:
1080
+ original_mask = original_mask[:,:,None]
1081
+
1082
+ dilation_type = np.random.choice(['square_erosion'])
1083
+ random_mask = random_mask_func(original_mask, dilation_type, dilation_size).squeeze()
1084
+
1085
+ mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
1086
+
1087
+ masked_image = original_image * (1 - (random_mask[:,:,None]>0))
1088
+ masked_image = masked_image.astype(original_image.dtype)
1089
+ masked_image = Image.fromarray(masked_image)
1090
+
1091
+
1092
+ return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
1093
+
1094
+
1095
+ def move_mask_left(input_image,
1096
+ original_image,
1097
+ original_mask,
1098
+ moving_pixels,
1099
+ resize_default,
1100
+ aspect_ratio_name):
1101
+
1102
+ alpha_mask = input_image["layers"][0].split()[3]
1103
+ input_mask = np.asarray(alpha_mask)
1104
+
1105
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
1106
+ if output_w == "" or output_h == "":
1107
+ output_h, output_w = original_image.shape[:2]
1108
+ if resize_default:
1109
+ short_side = min(output_w, output_h)
1110
+ scale_ratio = 640 / short_side
1111
+ output_w = int(output_w * scale_ratio)
1112
+ output_h = int(output_h * scale_ratio)
1113
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1114
+ original_image = np.array(original_image)
1115
+ if input_mask is not None:
1116
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1117
+ input_mask = np.array(input_mask)
1118
+ if original_mask is not None:
1119
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1120
+ original_mask = np.array(original_mask)
1121
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1122
+ else:
1123
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1124
+ pass
1125
+ else:
1126
+ if resize_default:
1127
+ short_side = min(output_w, output_h)
1128
+ scale_ratio = 640 / short_side
1129
+ output_w = int(output_w * scale_ratio)
1130
+ output_h = int(output_h * scale_ratio)
1131
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1132
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1133
+ original_image = np.array(original_image)
1134
+ if input_mask is not None:
1135
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1136
+ input_mask = np.array(input_mask)
1137
+ if original_mask is not None:
1138
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1139
+ original_mask = np.array(original_mask)
1140
+
1141
+ if input_mask.max() == 0:
1142
+ original_mask = original_mask
1143
+ else:
1144
+ original_mask = input_mask
1145
+
1146
+ if original_mask is None:
1147
+ raise gr.Error('Please generate mask first')
1148
+
1149
+ if original_mask.ndim == 2:
1150
+ original_mask = original_mask[:,:,None]
1151
+
1152
+ moved_mask = move_mask_func(original_mask, 'left', int(moving_pixels)).squeeze()
1153
+ mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
1154
+
1155
+ masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
1156
+ masked_image = masked_image.astype(original_image.dtype)
1157
+ masked_image = Image.fromarray(masked_image)
1158
+
1159
+ if moved_mask.max() <= 1:
1160
+ moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
1161
+ original_mask = moved_mask
1162
+ return [masked_image], [mask_image], original_mask.astype(np.uint8)
1163
+
1164
+
1165
+ def move_mask_right(input_image,
1166
+ original_image,
1167
+ original_mask,
1168
+ moving_pixels,
1169
+ resize_default,
1170
+ aspect_ratio_name):
1171
+ alpha_mask = input_image["layers"][0].split()[3]
1172
+ input_mask = np.asarray(alpha_mask)
1173
+
1174
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
1175
+ if output_w == "" or output_h == "":
1176
+ output_h, output_w = original_image.shape[:2]
1177
+ if resize_default:
1178
+ short_side = min(output_w, output_h)
1179
+ scale_ratio = 640 / short_side
1180
+ output_w = int(output_w * scale_ratio)
1181
+ output_h = int(output_h * scale_ratio)
1182
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1183
+ original_image = np.array(original_image)
1184
+ if input_mask is not None:
1185
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1186
+ input_mask = np.array(input_mask)
1187
+ if original_mask is not None:
1188
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1189
+ original_mask = np.array(original_mask)
1190
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1191
+ else:
1192
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1193
+ pass
1194
+ else:
1195
+ if resize_default:
1196
+ short_side = min(output_w, output_h)
1197
+ scale_ratio = 640 / short_side
1198
+ output_w = int(output_w * scale_ratio)
1199
+ output_h = int(output_h * scale_ratio)
1200
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1201
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1202
+ original_image = np.array(original_image)
1203
+ if input_mask is not None:
1204
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1205
+ input_mask = np.array(input_mask)
1206
+ if original_mask is not None:
1207
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1208
+ original_mask = np.array(original_mask)
1209
+
1210
+ if input_mask.max() == 0:
1211
+ original_mask = original_mask
1212
+ else:
1213
+ original_mask = input_mask
1214
+
1215
+ if original_mask is None:
1216
+ raise gr.Error('Please generate mask first')
1217
+
1218
+ if original_mask.ndim == 2:
1219
+ original_mask = original_mask[:,:,None]
1220
+
1221
+ moved_mask = move_mask_func(original_mask, 'right', int(moving_pixels)).squeeze()
1222
+
1223
+ mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
1224
+
1225
+ masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
1226
+ masked_image = masked_image.astype(original_image.dtype)
1227
+ masked_image = Image.fromarray(masked_image)
1228
+
1229
+
1230
+ if moved_mask.max() <= 1:
1231
+ moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
1232
+ original_mask = moved_mask
1233
+
1234
+ return [masked_image], [mask_image], original_mask.astype(np.uint8)
1235
+
1236
+
1237
+ def move_mask_up(input_image,
1238
+ original_image,
1239
+ original_mask,
1240
+ moving_pixels,
1241
+ resize_default,
1242
+ aspect_ratio_name):
1243
+ alpha_mask = input_image["layers"][0].split()[3]
1244
+ input_mask = np.asarray(alpha_mask)
1245
+
1246
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
1247
+ if output_w == "" or output_h == "":
1248
+ output_h, output_w = original_image.shape[:2]
1249
+ if resize_default:
1250
+ short_side = min(output_w, output_h)
1251
+ scale_ratio = 640 / short_side
1252
+ output_w = int(output_w * scale_ratio)
1253
+ output_h = int(output_h * scale_ratio)
1254
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1255
+ original_image = np.array(original_image)
1256
+ if input_mask is not None:
1257
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1258
+ input_mask = np.array(input_mask)
1259
+ if original_mask is not None:
1260
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1261
+ original_mask = np.array(original_mask)
1262
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1263
+ else:
1264
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1265
+ pass
1266
+ else:
1267
+ if resize_default:
1268
+ short_side = min(output_w, output_h)
1269
+ scale_ratio = 640 / short_side
1270
+ output_w = int(output_w * scale_ratio)
1271
+ output_h = int(output_h * scale_ratio)
1272
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1273
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1274
+ original_image = np.array(original_image)
1275
+ if input_mask is not None:
1276
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1277
+ input_mask = np.array(input_mask)
1278
+ if original_mask is not None:
1279
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1280
+ original_mask = np.array(original_mask)
1281
+
1282
+ if input_mask.max() == 0:
1283
+ original_mask = original_mask
1284
+ else:
1285
+ original_mask = input_mask
1286
+
1287
+ if original_mask is None:
1288
+ raise gr.Error('Please generate mask first')
1289
+
1290
+ if original_mask.ndim == 2:
1291
+ original_mask = original_mask[:,:,None]
1292
+
1293
+ moved_mask = move_mask_func(original_mask, 'up', int(moving_pixels)).squeeze()
1294
+ mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
1295
+
1296
+ masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
1297
+ masked_image = masked_image.astype(original_image.dtype)
1298
+ masked_image = Image.fromarray(masked_image)
1299
+
1300
+ if moved_mask.max() <= 1:
1301
+ moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
1302
+ original_mask = moved_mask
1303
+
1304
+ return [masked_image], [mask_image], original_mask.astype(np.uint8)
1305
+
1306
+
1307
+ def move_mask_down(input_image,
1308
+ original_image,
1309
+ original_mask,
1310
+ moving_pixels,
1311
+ resize_default,
1312
+ aspect_ratio_name):
1313
+ alpha_mask = input_image["layers"][0].split()[3]
1314
+ input_mask = np.asarray(alpha_mask)
1315
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
1316
+ if output_w == "" or output_h == "":
1317
+ output_h, output_w = original_image.shape[:2]
1318
+ if resize_default:
1319
+ short_side = min(output_w, output_h)
1320
+ scale_ratio = 640 / short_side
1321
+ output_w = int(output_w * scale_ratio)
1322
+ output_h = int(output_h * scale_ratio)
1323
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1324
+ original_image = np.array(original_image)
1325
+ if input_mask is not None:
1326
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1327
+ input_mask = np.array(input_mask)
1328
+ if original_mask is not None:
1329
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1330
+ original_mask = np.array(original_mask)
1331
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1332
+ else:
1333
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1334
+ pass
1335
+ else:
1336
+ if resize_default:
1337
+ short_side = min(output_w, output_h)
1338
+ scale_ratio = 640 / short_side
1339
+ output_w = int(output_w * scale_ratio)
1340
+ output_h = int(output_h * scale_ratio)
1341
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1342
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1343
+ original_image = np.array(original_image)
1344
+ if input_mask is not None:
1345
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1346
+ input_mask = np.array(input_mask)
1347
+ if original_mask is not None:
1348
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1349
+ original_mask = np.array(original_mask)
1350
+
1351
+ if input_mask.max() == 0:
1352
+ original_mask = original_mask
1353
+ else:
1354
+ original_mask = input_mask
1355
+
1356
+ if original_mask is None:
1357
+ raise gr.Error('Please generate mask first')
1358
+
1359
+ if original_mask.ndim == 2:
1360
+ original_mask = original_mask[:,:,None]
1361
+
1362
+ moved_mask = move_mask_func(original_mask, 'down', int(moving_pixels)).squeeze()
1363
+ mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
1364
+
1365
+ masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
1366
+ masked_image = masked_image.astype(original_image.dtype)
1367
+ masked_image = Image.fromarray(masked_image)
1368
+
1369
+ if moved_mask.max() <= 1:
1370
+ moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
1371
+ original_mask = moved_mask
1372
+
1373
+ return [masked_image], [mask_image], original_mask.astype(np.uint8)
1374
+
1375
+
1376
+ def invert_mask(input_image,
1377
+ original_image,
1378
+ original_mask,
1379
+ ):
1380
+ alpha_mask = input_image["layers"][0].split()[3]
1381
+ input_mask = np.asarray(alpha_mask)
1382
+ if original_mask is None and input_mask.max() == 0:
1383
+ raise gr.Error('Please generate mask first')
1384
+ if input_mask.max() == 0:
1385
+ original_mask = 1 - (original_mask>0).astype(np.uint8)
1386
+ else:
1387
+ original_mask = 1 - (input_mask>0).astype(np.uint8)
1388
+
1389
+
1390
+ original_mask = original_mask.squeeze()
1391
+ mask_image = Image.fromarray(original_mask*255).convert("RGB")
1392
+
1393
+ if original_mask.ndim == 2:
1394
+ original_mask = original_mask[:,:,None]
1395
+
1396
+ if original_mask.max() <= 1:
1397
+ original_mask = (original_mask * 255).astype(np.uint8)
1398
+
1399
+ masked_image = original_image * (1 - (original_mask>0))
1400
+ masked_image = masked_image.astype(original_image.dtype)
1401
+ masked_image = Image.fromarray(masked_image)
1402
+
1403
+ return [masked_image], [mask_image], original_mask, True
1404
+
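+ # Illustrative sketch of the inversion arithmetic used above (hypothetical example for clarity):
+ # >>> m = np.zeros((4, 4), dtype=np.uint8); m[1:3, 1:3] = 255
+ # >>> inv = 1 - (m > 0).astype(np.uint8)       # foreground becomes 0, background becomes 1
+ # >>> int(inv.sum())                           # 12 of the 16 pixels are now selected
+ # 12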
1405
+
1406
+ def init_img(base,
1407
+ init_type,
1408
+ prompt,
1409
+ aspect_ratio,
1410
+ example_change_times
1411
+ ):
1412
+ image_pil = base["background"].convert("RGB")
1413
+ original_image = np.array(image_pil)
1414
+ if max(original_image.shape[0], original_image.shape[1]) * 1.0 / min(original_image.shape[0], original_image.shape[1])>2.0:
1415
+ raise gr.Error('image aspect ratio cannot be larger than 2.0')
1416
+ if init_type in MASK_IMAGE_PATH.keys() and example_change_times < 2:
1417
+ mask_gallery = [Image.open(MASK_IMAGE_PATH[init_type]).convert("L")]
1418
+ masked_gallery = [Image.open(MASKED_IMAGE_PATH[init_type]).convert("RGB")]
1419
+ result_gallery = [Image.open(OUTPUT_IMAGE_PATH[init_type]).convert("RGB")]
1420
+ width, height = image_pil.size
1421
+ image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
1422
+ height_new, width_new = image_processor.get_default_height_width(image_pil, height, width)
1423
+ image_pil = image_pil.resize((width_new, height_new))
1424
+ mask_gallery[0] = mask_gallery[0].resize((width_new, height_new))
1425
+ masked_gallery[0] = masked_gallery[0].resize((width_new, height_new))
1426
+ result_gallery[0] = result_gallery[0].resize((width_new, height_new))
1427
+ original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1
1428
+ return base, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, "", "", "Custom resolution", False, False, example_change_times
1429
+ else:
1430
+ if aspect_ratio not in ASPECT_RATIO_LABELS:
1431
+ aspect_ratio = "Custom resolution"
1432
+ return base, original_image, None, "", None, None, None, "", "", aspect_ratio, True, False, 0
1433
+
1434
+
1435
+ def reset_func(input_image,
1436
+ original_image,
1437
+ original_mask,
1438
+ prompt,
1439
+ target_prompt,
1440
+ ):
1441
+ input_image = None
1442
+ original_image = None
1443
+ original_mask = None
1444
+ prompt = ''
1445
+ mask_gallery = []
1446
+ masked_gallery = []
1447
+ result_gallery = []
1448
+ target_prompt = ''
1449
+ if torch.cuda.is_available():
1450
+ torch.cuda.empty_cache()
1451
+ return input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, True, False
1452
+
1453
+
1454
+ def update_example(example_type,
1455
+ prompt,
1456
+ example_change_times):
1457
+ input_image = INPUT_IMAGE_PATH[example_type]
1458
+ image_pil = Image.open(input_image).convert("RGB")
1459
+ mask_gallery = [Image.open(MASK_IMAGE_PATH[example_type]).convert("L")]
1460
+ masked_gallery = [Image.open(MASKED_IMAGE_PATH[example_type]).convert("RGB")]
1461
+ result_gallery = [Image.open(OUTPUT_IMAGE_PATH[example_type]).convert("RGB")]
1462
+ width, height = image_pil.size
1463
+ image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
1464
+ height_new, width_new = image_processor.get_default_height_width(image_pil, height, width)
1465
+ image_pil = image_pil.resize((width_new, height_new))
1466
+ mask_gallery[0] = mask_gallery[0].resize((width_new, height_new))
1467
+ masked_gallery[0] = masked_gallery[0].resize((width_new, height_new))
1468
+ result_gallery[0] = result_gallery[0].resize((width_new, height_new))
1469
+
1470
+ original_image = np.array(image_pil)
1471
+ original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1
1472
+ aspect_ratio = "Custom resolution"
1473
+ example_change_times += 1
1474
+ return input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, "", False, example_change_times
1475
+
1476
+
1477
+ block = gr.Blocks(
1478
+ theme=gr.themes.Soft(
1479
+ radius_size=gr.themes.sizes.radius_none,
1480
+ text_size=gr.themes.sizes.text_md
1481
+ )
1482
+ )
1483
+ with block as demo:
1484
+ with gr.Row():
1485
+ with gr.Column():
1486
+ gr.HTML(head)
1487
+
1488
+ gr.Markdown(descriptions)
1489
+
1490
+ with gr.Accordion(label="🧭 Instructions:", open=True, elem_id="accordion"):
1491
+ with gr.Row(equal_height=True):
1492
+ gr.Markdown(instructions)
1493
+
1494
+ original_image = gr.State(value=None)
1495
+ original_mask = gr.State(value=None)
1496
+ category = gr.State(value=None)
1497
+ status = gr.State(value=None)
1498
+ invert_mask_state = gr.State(value=False)
1499
+ example_change_times = gr.State(value=0)
1500
+
1501
+
1502
+ with gr.Row():
1503
+ with gr.Column():
1504
+ with gr.Row():
1505
+ input_image = gr.ImageEditor(
1506
+ label="Input Image",
1507
+ type="pil",
1508
+ brush=gr.Brush(colors=["#FFFFFF"], default_size = 30, color_mode="fixed"),
1509
+ layers = False,
1510
+ interactive=True,
1511
+ height=1024,
1512
+ sources=["upload"],
1513
+ placeholder="Please click here or the icon below to upload the image.",
1514
+ )
1515
+
1516
+ prompt = gr.Textbox(label="⌨️ Instruction", placeholder="Please input your instruction.", value="",lines=1)
1517
+ run_button = gr.Button("💫 Run")
1518
+
1519
+ vlm_model_dropdown = gr.Dropdown(label="VLM model", choices=VLM_MODEL_NAMES, value=DEFAULT_VLM_MODEL_NAME, interactive=True)
1520
+ with gr.Group():
1521
+ with gr.Row():
1522
+ GPT4o_KEY = gr.Textbox(label="GPT4o API Key", placeholder="Please input your GPT4o API Key when using the GPT4o VLM (highly recommended).", value="", lines=1)
1523
+
1524
+ GPT4o_KEY_submit = gr.Button("Submit and Verify")
1525
+
1526
+
1527
+ aspect_ratio = gr.Dropdown(label="Output aspect ratio", choices=ASPECT_RATIO_LABELS, value=DEFAULT_ASPECT_RATIO)
1528
+ resize_default = gr.Checkbox(label="Short edge resize to 640px", value=True)
1529
+
1530
+ with gr.Row():
1531
+ mask_button = gr.Button("Generate Mask")
1532
+ random_mask_button = gr.Button("Square/Circle Mask ")
1533
+
1534
+
1535
+ with gr.Row():
1536
+ generate_target_prompt_button = gr.Button("Generate Target Prompt")
1537
+
1538
+ target_prompt = gr.Text(
1539
+ label="Input Target Prompt",
1540
+ max_lines=5,
1541
+ placeholder="VLM-generated target prompt, you can first generate if and then modify it (optional)",
1542
+ value='',
1543
+ lines=2
1544
+ )
1545
+
1546
+ with gr.Accordion("Advanced Options", open=False, elem_id="accordion1"):
1547
+ base_model_dropdown = gr.Dropdown(label="Base model", choices=BASE_MODELS, value=DEFAULT_BASE_MODEL, interactive=True)
1548
+ negative_prompt = gr.Text(
1549
+ label="Negative Prompt",
1550
+ max_lines=5,
1551
+ placeholder="Please input your negative prompt",
1552
+ value='ugly, low quality',lines=1
1553
+ )
1554
+
1555
+ control_strength = gr.Slider(
1556
+ label="Control Strength: ", show_label=True, minimum=0, maximum=1.1, value=1, step=0.01
1557
+ )
1558
+ with gr.Group():
1559
+ seed = gr.Slider(
1560
+ label="Seed: ", minimum=0, maximum=2147483647, step=1, value=648464818
1561
+ )
1562
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
1563
+
1564
+ blending = gr.Checkbox(label="Blending mode", value=True)
1565
+
1566
+
1567
+ num_samples = gr.Slider(
1568
+ label="Num samples", minimum=0, maximum=4, step=1, value=4
1569
+ )
1570
+
1571
+ with gr.Group():
1572
+ with gr.Row():
1573
+ guidance_scale = gr.Slider(
1574
+ label="Guidance scale",
1575
+ minimum=1,
1576
+ maximum=12,
1577
+ step=0.1,
1578
+ value=7.5,
1579
+ )
1580
+ num_inference_steps = gr.Slider(
1581
+ label="Number of inference steps",
1582
+ minimum=1,
1583
+ maximum=50,
1584
+ step=1,
1585
+ value=50,
1586
+ )
1587
+
1588
+
1589
+ with gr.Column():
1590
+ with gr.Row():
1591
+ with gr.Tab(elem_classes="feedback", label="Masked Image"):
1592
+ masked_gallery = gr.Gallery(label='Masked Image', show_label=True, elem_id="gallery", preview=True, height=360)
1593
+ with gr.Tab(elem_classes="feedback", label="Mask"):
1594
+ mask_gallery = gr.Gallery(label='Mask', show_label=True, elem_id="gallery", preview=True, height=360)
1595
+
1596
+ invert_mask_button = gr.Button("Invert Mask")
1597
+ dilation_size = gr.Slider(
1598
+ label="Dilation size: ", minimum=0, maximum=50, step=1, value=20
1599
+ )
1600
+ with gr.Row():
1601
+ dilation_mask_button = gr.Button("Dilation Generated Mask")
1602
+ erosion_mask_button = gr.Button("Erosion Generated Mask")
1603
+
1604
+ moving_pixels = gr.Slider(
1605
+ label="Moving pixels:", show_label=True, minimum=0, maximum=50, value=4, step=1
1606
+ )
1607
+ with gr.Row():
1608
+ move_left_button = gr.Button("Move Left")
1609
+ move_right_button = gr.Button("Move Right")
1610
+ with gr.Row():
1611
+ move_up_button = gr.Button("Move Up")
1612
+ move_down_button = gr.Button("Move Down")
1613
+
1614
+ with gr.Tab(elem_classes="feedback", label="Output"):
1615
+ result_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", preview=True, height=400)
1616
+
1617
+ # target_prompt_output = gr.Text(label="Output Target Prompt", value="", lines=1, interactive=False)
1618
+
1619
+ reset_button = gr.Button("Reset")
1620
+
1621
+ init_type = gr.Textbox(label="Init Name", value="", visible=False)
1622
+ example_type = gr.Textbox(label="Example Name", value="", visible=False)
1623
+
1624
+
1625
+
1626
+ with gr.Row():
1627
+ example = gr.Examples(
1628
+ label="Quick Example",
1629
+ examples=EXAMPLES,
1630
+ inputs=[input_image, prompt, seed, init_type, example_type, blending, resize_default, vlm_model_dropdown],
1631
+ examples_per_page=10,
1632
+ cache_examples=False,
1633
+ )
1634
+
1635
+
1636
+ with gr.Accordion(label="🎬 Feature Details:", open=True, elem_id="accordion"):
1637
+ with gr.Row(equal_height=True):
1638
+ gr.Markdown(tips)
1639
+
1640
+ with gr.Row():
1641
+ gr.Markdown(citation)
1642
+
1643
+ ## gr.Examples cannot be used to update a gr.Gallery, so the following two callbacks update the galleries instead.
1644
+ ## They also resolve the conflict between uploading a new image and switching to an example.
1645
+ input_image.upload(
1646
+ init_img,
1647
+ [input_image, init_type, prompt, aspect_ratio, example_change_times],
1648
+ [input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, init_type, aspect_ratio, resize_default, invert_mask_state, example_change_times]
1649
+ )
1650
+ example_type.change(fn=update_example, inputs=[example_type, prompt, example_change_times], outputs=[input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, target_prompt, invert_mask_state, example_change_times])
1651
+
1652
+ ## vlm and base model dropdown
1653
+ vlm_model_dropdown.change(fn=update_vlm_model, inputs=[vlm_model_dropdown], outputs=[status])
1654
+ base_model_dropdown.change(fn=update_base_model, inputs=[base_model_dropdown], outputs=[status])
1655
+
1656
+
1657
+ GPT4o_KEY_submit.click(fn=submit_GPT4o_KEY, inputs=[GPT4o_KEY], outputs=[GPT4o_KEY, vlm_model_dropdown])
1658
+ invert_mask_button.click(fn=invert_mask, inputs=[input_image, original_image, original_mask], outputs=[masked_gallery, mask_gallery, original_mask, invert_mask_state])
1659
+
1660
+
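+ ## bundle the inputs consumed by the main edit pipeline (passed to run_button.click below)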
1661
+ ips=[input_image,
1662
+ original_image,
1663
+ original_mask,
1664
+ prompt,
1665
+ negative_prompt,
1666
+ control_strength,
1667
+ seed,
1668
+ randomize_seed,
1669
+ guidance_scale,
1670
+ num_inference_steps,
1671
+ num_samples,
1672
+ blending,
1673
+ category,
1674
+ target_prompt,
1675
+ resize_default,
1676
+ aspect_ratio,
1677
+ invert_mask_state]
1678
+
1679
+ ## run brushedit
1680
+ run_button.click(fn=process, inputs=ips, outputs=[result_gallery, mask_gallery, masked_gallery, prompt, target_prompt, invert_mask_state])
1681
+
1682
+ ## mask func
1683
+ mask_button.click(fn=process_mask, inputs=[input_image, original_image, prompt, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask, category])
1684
+ random_mask_button.click(fn=process_random_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
1685
+ dilation_mask_button.click(fn=process_dilation_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[ masked_gallery, mask_gallery, original_mask])
1686
+ erosion_mask_button.click(fn=process_erosion_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[ masked_gallery, mask_gallery, original_mask])
1687
+
1688
+ ## move mask func
1689
+ move_left_button.click(fn=move_mask_left, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
1690
+ move_right_button.click(fn=move_mask_right, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
1691
+ move_up_button.click(fn=move_mask_up, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
1692
+ move_down_button.click(fn=move_mask_down, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
1693
+
1694
+ ## prompt func
1695
+ generate_target_prompt_button.click(fn=generate_target_prompt, inputs=[input_image, original_image, prompt], outputs=[target_prompt])
1696
+
1697
+ ## reset func
1698
+ reset_button.click(fn=reset_func, inputs=[input_image, original_image, original_mask, prompt, target_prompt], outputs=[input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, resize_default, invert_mask_state])
1699
+
1700
+ ## if you hit a localhost access error, try the following launch settings
1701
+ demo.launch(server_name="0.0.0.0", server_port=12345, share=True)
1702
+ # demo.launch()
brushedit_app_only_integrate.py ADDED
@@ -0,0 +1,1725 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import os, random, sys
4
+ import numpy as np
5
+ import requests
6
+ import torch
7
+
8
+
9
+ import gradio as gr
10
+
11
+ from PIL import Image
12
+
13
+
14
+ from huggingface_hub import hf_hub_download, snapshot_download
15
+ from scipy.ndimage import binary_dilation, binary_erosion
16
+ from transformers import (LlavaNextProcessor, LlavaNextForConditionalGeneration,
17
+ Qwen2VLForConditionalGeneration, Qwen2VLProcessor)
18
+
19
+ from segment_anything import SamPredictor, build_sam, SamAutomaticMaskGenerator
20
+ from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler
21
+ from diffusers.image_processor import VaeImageProcessor
22
+
23
+
24
+ from app.src.vlm_pipeline import (
25
+ vlm_response_editing_type,
26
+ vlm_response_object_wait_for_edit,
27
+ vlm_response_mask,
28
+ vlm_response_prompt_after_apply_instruction
29
+ )
30
+ from app.src.brushedit_all_in_one_pipeline import BrushEdit_Pipeline
31
+ from app.utils.utils import load_grounding_dino_model
32
+
33
+ from app.src.vlm_template import vlms_template
34
+ from app.src.base_model_template import base_models_template
35
+ from app.src.aspect_ratio_template import aspect_ratios
36
+
37
+ from openai import OpenAI
38
+ # base_openai_url = "https://api.deepseek.com/"
39
+
40
+
41
+ from transformers import BlipProcessor, BlipForConditionalGeneration
42
+
43
+ from app.deepseek.instructions import create_apply_editing_messages_deepseek
44
+
45
+
46
+ #### Description ####
47
+ logo = r"""
48
+ <center><img src='./assets/logo_brushedit.png' alt='BrushEdit logo' style="width:80px; margin-bottom:10px"></center>
49
+ """
50
+ head = r"""
51
+ <div style="text-align: center;">
52
+ <h1> BrushEdit: All-In-One Image Inpainting and Editing</h1>
53
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
54
+ <a href='https://liyaowei-stu.github.io/project/BrushEdit/'><img src='https://img.shields.io/badge/Project_Page-BrushEdit-green' alt='Project Page'></a>
55
+ <a href='https://arxiv.org/abs/2412.10316'><img src='https://img.shields.io/badge/Paper-Arxiv-blue'></a>
56
+ <a href='https://github.com/TencentARC/BrushEdit'><img src='https://img.shields.io/badge/Code-Github-orange'></a>
57
+
58
+ </div>
59
+ </br>
60
+ </div>
61
+ """
62
+ descriptions = r"""
63
+ Demo for CIR"""
64
+
65
+ instructions = r"""
66
+ Demo for CIR"""
67
+
68
+ tips = r"""
69
+ Demo for CIR
70
+
71
+ """
72
+
73
+
74
+
75
+ citation = r"""
76
+ Demo for CIR"""
77
+
78
+ # - - - - - examples - - - - - #
79
+ EXAMPLES = [
80
+
81
+ [
82
+ Image.open("./assets/frog/frog.jpeg").convert("RGBA"),
83
+ "add a magic hat on frog head.",
84
+ 642087011,
85
+ "frog",
86
+ "frog",
87
+ True,
88
+ False,
89
+ "GPT4-o (Highly Recommended)"
90
+ ],
91
+ [
92
+ Image.open("./assets/chinese_girl/chinese_girl.png").convert("RGBA"),
93
+ "replace the background to ancient China.",
94
+ 648464818,
95
+ "chinese_girl",
96
+ "chinese_girl",
97
+ True,
98
+ False,
99
+ "GPT4-o (Highly Recommended)"
100
+ ],
101
+ [
102
+ Image.open("./assets/angel_christmas/angel_christmas.png").convert("RGBA"),
103
+ "remove the deer.",
104
+ 648464818,
105
+ "angel_christmas",
106
+ "angel_christmas",
107
+ False,
108
+ False,
109
+ "GPT4-o (Highly Recommended)"
110
+ ],
111
+ [
112
+ Image.open("./assets/sunflower_girl/sunflower_girl.png").convert("RGBA"),
113
+ "add a wreath on head.",
114
+ 648464818,
115
+ "sunflower_girl",
116
+ "sunflower_girl",
117
+ True,
118
+ False,
119
+ "GPT4-o (Highly Recommended)"
120
+ ],
121
+ [
122
+ Image.open("./assets/girl_on_sun/girl_on_sun.png").convert("RGBA"),
123
+ "add a butterfly fairy.",
124
+ 648464818,
125
+ "girl_on_sun",
126
+ "girl_on_sun",
127
+ True,
128
+ False,
129
+ "GPT4-o (Highly Recommended)"
130
+ ],
131
+ [
132
+ Image.open("./assets/spider_man_rm/spider_man.png").convert("RGBA"),
133
+ "remove the christmas hat.",
134
+ 642087011,
135
+ "spider_man_rm",
136
+ "spider_man_rm",
137
+ False,
138
+ False,
139
+ "GPT4-o (Highly Recommended)"
140
+ ],
141
+ [
142
+ Image.open("./assets/anime_flower/anime_flower.png").convert("RGBA"),
143
+ "remove the flower.",
144
+ 642087011,
145
+ "anime_flower",
146
+ "anime_flower",
147
+ False,
148
+ False,
149
+ "GPT4-o (Highly Recommended)"
150
+ ],
151
+ [
152
+ Image.open("./assets/chenduling/chengduling.jpg").convert("RGBA"),
153
+ "replace the clothes to a delicated floral skirt.",
154
+ 648464818,
155
+ "chenduling",
156
+ "chenduling",
157
+ True,
158
+ False,
159
+ "GPT4-o (Highly Recommended)"
160
+ ],
161
+ [
162
+ Image.open("./assets/hedgehog_rp_bg/hedgehog.png").convert("RGBA"),
163
+ "make the hedgehog in Italy.",
164
+ 648464818,
165
+ "hedgehog_rp_bg",
166
+ "hedgehog_rp_bg",
167
+ True,
168
+ False,
169
+ "GPT4-o (Highly Recommended)"
170
+ ],
171
+
172
+ ]
173
+
174
+ INPUT_IMAGE_PATH = {
175
+ "frog": "./assets/frog/frog.jpeg",
176
+ "chinese_girl": "./assets/chinese_girl/chinese_girl.png",
177
+ "angel_christmas": "./assets/angel_christmas/angel_christmas.png",
178
+ "sunflower_girl": "./assets/sunflower_girl/sunflower_girl.png",
179
+ "girl_on_sun": "./assets/girl_on_sun/girl_on_sun.png",
180
+ "spider_man_rm": "./assets/spider_man_rm/spider_man.png",
181
+ "anime_flower": "./assets/anime_flower/anime_flower.png",
182
+ "chenduling": "./assets/chenduling/chengduling.jpg",
183
+ "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/hedgehog.png",
184
+ }
185
+ MASK_IMAGE_PATH = {
186
+ "frog": "./assets/frog/mask_f7b350de-6f2c-49e3-b535-995c486d78e7.png",
187
+ "chinese_girl": "./assets/chinese_girl/mask_54759648-0989-48e0-bc82-f20e28b5ec29.png",
188
+ "angel_christmas": "./assets/angel_christmas/mask_f15d9b45-c978-4e3d-9f5f-251e308560c3.png",
189
+ "sunflower_girl": "./assets/sunflower_girl/mask_99cc50b4-7dc4-4de5-8748-ec10772f0317.png",
190
+ "girl_on_sun": "./assets/girl_on_sun/mask_264eac8b-8b65-479c-9755-020a60880c37.png",
191
+ "spider_man_rm": "./assets/spider_man_rm/mask_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png",
192
+ "anime_flower": "./assets/anime_flower/mask_37553172-9b38-4727-bf2e-37d7e2b93461.png",
193
+ "chenduling": "./assets/chenduling/mask_68e3ff6f-da07-4b37-91df-13d6eed7b997.png",
194
+ "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/mask_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png",
195
+ }
196
+ MASKED_IMAGE_PATH = {
197
+ "frog": "./assets/frog/masked_image_f7b350de-6f2c-49e3-b535-995c486d78e7.png",
198
+ "chinese_girl": "./assets/chinese_girl/masked_image_54759648-0989-48e0-bc82-f20e28b5ec29.png",
199
+ "angel_christmas": "./assets/angel_christmas/masked_image_f15d9b45-c978-4e3d-9f5f-251e308560c3.png",
200
+ "sunflower_girl": "./assets/sunflower_girl/masked_image_99cc50b4-7dc4-4de5-8748-ec10772f0317.png",
201
+ "girl_on_sun": "./assets/girl_on_sun/masked_image_264eac8b-8b65-479c-9755-020a60880c37.png",
202
+ "spider_man_rm": "./assets/spider_man_rm/masked_image_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png",
203
+ "anime_flower": "./assets/anime_flower/masked_image_37553172-9b38-4727-bf2e-37d7e2b93461.png",
204
+ "chenduling": "./assets/chenduling/masked_image_68e3ff6f-da07-4b37-91df-13d6eed7b997.png",
205
+ "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/masked_image_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png",
206
+ }
207
+ OUTPUT_IMAGE_PATH = {
208
+ "frog": "./assets/frog/image_edit_f7b350de-6f2c-49e3-b535-995c486d78e7_1.png",
209
+ "chinese_girl": "./assets/chinese_girl/image_edit_54759648-0989-48e0-bc82-f20e28b5ec29_1.png",
210
+ "angel_christmas": "./assets/angel_christmas/image_edit_f15d9b45-c978-4e3d-9f5f-251e308560c3_0.png",
211
+ "sunflower_girl": "./assets/sunflower_girl/image_edit_99cc50b4-7dc4-4de5-8748-ec10772f0317_3.png",
212
+ "girl_on_sun": "./assets/girl_on_sun/image_edit_264eac8b-8b65-479c-9755-020a60880c37_0.png",
213
+ "spider_man_rm": "./assets/spider_man_rm/image_edit_a5d410e6-8e8d-432f-8144-defbc3e1eae9_0.png",
214
+ "anime_flower": "./assets/anime_flower/image_edit_37553172-9b38-4727-bf2e-37d7e2b93461_2.png",
215
+ "chenduling": "./assets/chenduling/image_edit_68e3ff6f-da07-4b37-91df-13d6eed7b997_0.png",
216
+ "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/image_edit_db7f8bf8-8349-46d3-b14e-43d67fbe25d3_3.png",
217
+ }
218
+
219
+ # os.environ['GRADIO_TEMP_DIR'] = 'gradio_temp_dir'
220
+ # os.makedirs('gradio_temp_dir', exist_ok=True)
221
+
222
+ VLM_MODEL_NAMES = list(vlms_template.keys())
223
+ DEFAULT_VLM_MODEL_NAME = "Qwen2-VL-7B-Instruct (Default)"
224
+
225
+
226
+ BASE_MODELS = list(base_models_template.keys())
227
+ DEFAULT_BASE_MODEL = "realisticVision (Default)"
228
+
229
+ ASPECT_RATIO_LABELS = list(aspect_ratios)
230
+ DEFAULT_ASPECT_RATIO = ASPECT_RATIO_LABELS[0]
231
+
232
+
233
+ ## init device
234
+ try:
235
+ if torch.cuda.is_available():
236
+ device = "cuda"
237
+ elif sys.platform == "darwin" and torch.backends.mps.is_available():
238
+ device = "mps"
239
+ else:
240
+ device = "cpu"
241
+ except Exception:
242
+ device = "cpu"
243
+
244
+ # ## init torch dtype
245
+ # if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
246
+ # torch_dtype = torch.bfloat16
247
+ # else:
248
+ # torch_dtype = torch.float16
249
+
250
+ # if device == "mps":
251
+ # torch_dtype = torch.float16
252
+
253
+ torch_dtype = torch.float16
254
+
255
+
256
+
257
+ # download hf models
258
+ BrushEdit_path = "models/"
259
+ if not os.path.exists(BrushEdit_path):
260
+ BrushEdit_path = snapshot_download(
261
+ repo_id="TencentARC/BrushEdit",
262
+ local_dir=BrushEdit_path,
263
+ token=os.getenv("HF_TOKEN"),
264
+ )
265
+
266
+ ## init default VLM
267
+ vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[DEFAULT_VLM_MODEL_NAME]
268
+ if vlm_processor != "" and vlm_model != "":
269
+ vlm_model.to(device)
270
+ else:
271
+ raise gr.Error("Please Download default VLM model "+ DEFAULT_VLM_MODEL_NAME +" first.")
272
+
273
+ def initialize_llm_model():
274
+ global llm_model
275
+ llm_model = OpenAI(api_key="sk-d145b963a92649a88843caeb741e8bbc", base_url="https://api.deepseek.com")
276
+
277
+ ## init base model
278
+ base_model_path = os.path.join(BrushEdit_path, "base_model/realisticVisionV60B1_v51VAE")
279
+ brushnet_path = os.path.join(BrushEdit_path, "brushnetX")
280
+ sam_path = os.path.join(BrushEdit_path, "sam/sam_vit_h_4b8939.pth")
281
+ groundingdino_path = os.path.join(BrushEdit_path, "grounding_dino/groundingdino_swint_ogc.pth")
282
+
283
+
284
+ # input brushnetX ckpt path
285
+ brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch_dtype)
286
+ pipe = StableDiffusionBrushNetPipeline.from_pretrained(
287
+ base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
288
+ )
289
+ # speed up diffusion process with faster scheduler and memory optimization
290
+ pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
291
+ # uncomment the following line to enable xformers memory-efficient attention (requires xformers; not needed with Torch 2.0)
292
+ # pipe.enable_xformers_memory_efficient_attention()
293
+ pipe.enable_model_cpu_offload()
294
+
295
+
296
+ ## init SAM
297
+ sam = build_sam(checkpoint=sam_path)
298
+ sam.to(device=device)
299
+ sam_predictor = SamPredictor(sam)
300
+ sam_automask_generator = SamAutomaticMaskGenerator(sam)
301
+
302
+ ## init groundingdino_model
303
+ config_file = 'app/utils/GroundingDINO_SwinT_OGC.py'
304
+ groundingdino_model = load_grounding_dino_model(config_file, groundingdino_path, device=device)
305
+
306
+ ## Ordinary function
307
+ def crop_and_resize(image: Image.Image,
308
+ target_width: int,
309
+ target_height: int) -> Image.Image:
310
+ """
311
+ Crops and resizes an image while preserving the aspect ratio.
312
+
313
+ Args:
314
+ image (Image.Image): Input PIL image to be cropped and resized.
315
+ target_width (int): Target width of the output image.
316
+ target_height (int): Target height of the output image.
317
+
318
+ Returns:
319
+ Image.Image: Cropped and resized image.
320
+ """
321
+ # Original dimensions
322
+ original_width, original_height = image.size
323
+ original_aspect = original_width / original_height
324
+ target_aspect = target_width / target_height
325
+
326
+ # Calculate crop box to maintain aspect ratio
327
+ if original_aspect > target_aspect:
328
+ # Crop horizontally
329
+ new_width = int(original_height * target_aspect)
330
+ new_height = original_height
331
+ left = (original_width - new_width) / 2
332
+ top = 0
333
+ right = left + new_width
334
+ bottom = original_height
335
+ else:
336
+ # Crop vertically
337
+ new_width = original_width
338
+ new_height = int(original_width / target_aspect)
339
+ left = 0
340
+ top = (original_height - new_height) / 2
341
+ right = original_width
342
+ bottom = top + new_height
343
+
344
+ # Crop and resize
345
+ cropped_image = image.crop((left, top, right, bottom))
346
+ resized_image = cropped_image.resize((target_width, target_height), Image.NEAREST)
347
+ return resized_image
348
+
349
+
350
+ ## Ordinary function
351
+ def resize(image: Image.Image,
352
+ target_width: int,
353
+ target_height: int) -> Image.Image:
354
+ """
355
+ Resizes an image to the target dimensions without preserving the aspect ratio.
356
+
357
+ Args:
358
+ image (Image.Image): Input PIL image to be resized.
359
+ target_width (int): Target width of the output image.
360
+ target_height (int): Target height of the output image.
361
+
362
+ Returns:
363
+ Image.Image: Resized image.
364
+ """
365
+ # Resize directly to the target dimensions
366
+ resized_image = image.resize((target_width, target_height), Image.NEAREST)
367
+ return resized_image
368
+
369
+
370
+ def move_mask_func(mask, direction, units):
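+ # shift the binary mask by the given number of pixels in the chosen direction; areas shifted in from the border remain empty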
371
+ binary_mask = mask.squeeze()>0
372
+ rows, cols = binary_mask.shape
373
+ moved_mask = np.zeros_like(binary_mask, dtype=bool)
374
+
375
+ if direction == 'down':
376
+ # move down
377
+ moved_mask[max(0, units):, :] = binary_mask[:rows - units, :]
378
+
379
+ elif direction == 'up':
380
+ # move up
381
+ moved_mask[:rows - units, :] = binary_mask[units:, :]
382
+
383
+ elif direction == 'right':
384
+ # move right
385
+ moved_mask[:, max(0, units):] = binary_mask[:, :cols - units]
386
+
387
+ elif direction == 'left':
388
+ # move left
389
+ moved_mask[:, :cols - units] = binary_mask[:, units:]
390
+
391
+ return moved_mask
392
+
393
+
394
+ def random_mask_func(mask, dilation_type='square', dilation_size=20):
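+ # expand or shrink the mask: 'square_dilation' / 'square_erosion' use a square structuring element; 'bounding_box' / 'bounding_ellipse' replace the mask with its bounding shape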
395
+ # binarize the mask before applying the selected operation
396
+ binary_mask = mask.squeeze()>0
397
+
398
+ if dilation_type == 'square_dilation':
399
+ structure = np.ones((dilation_size, dilation_size), dtype=bool)
400
+ dilated_mask = binary_dilation(binary_mask, structure=structure)
401
+ elif dilation_type == 'square_erosion':
402
+ structure = np.ones((dilation_size, dilation_size), dtype=bool)
403
+ dilated_mask = binary_erosion(binary_mask, structure=structure)
404
+ elif dilation_type == 'bounding_box':
405
+ # find the bounding coordinates of the mask
406
+ rows, cols = np.where(binary_mask)
407
+ if len(rows) == 0 or len(cols) == 0:
408
+ return mask # return original mask if no valid points
409
+
410
+ min_row = np.min(rows)
411
+ max_row = np.max(rows)
412
+ min_col = np.min(cols)
413
+ max_col = np.max(cols)
414
+
415
+ # create a bounding box
416
+ dilated_mask = np.zeros_like(binary_mask, dtype=bool)
417
+ dilated_mask[min_row:max_row + 1, min_col:max_col + 1] = True
418
+
419
+ elif dilation_type == 'bounding_ellipse':
420
+ # find the bounding coordinates of the mask
421
+ rows, cols = np.where(binary_mask)
422
+ if len(rows) == 0 or len(cols) == 0:
423
+ return mask # return original mask if no valid points
424
+
425
+ min_row = np.min(rows)
426
+ max_row = np.max(rows)
427
+ min_col = np.min(cols)
428
+ max_col = np.max(cols)
429
+
430
+ # calculate the center and axis length of the ellipse
431
+ center = ((min_col + max_col) // 2, (min_row + max_row) // 2)
432
+ a = (max_col - min_col) // 2 # half long axis
433
+ b = (max_row - min_row) // 2 # half short axis
434
+
435
+ # create a bounding ellipse
436
+ y, x = np.ogrid[:mask.shape[0], :mask.shape[1]]
437
+ ellipse_mask = ((x - center[0])**2 / a**2 + (y - center[1])**2 / b**2) <= 1
438
+ dilated_mask = np.zeros_like(binary_mask, dtype=bool)
439
+ dilated_mask[ellipse_mask] = True
440
+ else:
441
+ ValueError("dilation_type must be 'square' or 'ellipse'")
442
+
443
+ # use binary dilation
444
+ dilated_mask = np.uint8(dilated_mask[:,:,np.newaxis]) * 255
445
+ return dilated_mask
446
+
447
+
448
+ ## Gradio component function
449
+ def update_vlm_model(vlm_name):
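+ # release the current VLM from GPU memory, then load the newly selected one (preloaded weights if available, otherwise a local path or the Hugging Face Hub)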
450
+ global vlm_model, vlm_processor
451
+ if vlm_model is not None:
452
+ del vlm_model
453
+ torch.cuda.empty_cache()
454
+
455
+ vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[vlm_name]
456
+
457
+ ## We recommend using preloaded models; otherwise the first load will spend a long time downloading weights. You can adjust this in vlm_template.py.
458
+ if vlm_type == "llava-next":
459
+ if vlm_processor != "" and vlm_model != "":
460
+ vlm_model.to(device)
461
+ return vlm_model_dropdown
462
+ else:
463
+ if os.path.exists(vlm_local_path):
464
+ vlm_processor = LlavaNextProcessor.from_pretrained(vlm_local_path)
465
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto")
466
+ else:
467
+ if vlm_name == "llava-v1.6-mistral-7b-hf (Preload)":
468
+ vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
469
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", torch_dtype="auto", device_map="auto")
470
+ elif vlm_name == "llama3-llava-next-8b-hf (Preload)":
471
+ vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llama3-llava-next-8b-hf")
472
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llama3-llava-next-8b-hf", torch_dtype="auto", device_map="auto")
473
+ elif vlm_name == "llava-v1.6-vicuna-13b-hf (Preload)":
474
+ vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf")
475
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf", torch_dtype="auto", device_map="auto")
476
+ elif vlm_name == "llava-v1.6-34b-hf (Preload)":
477
+ vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-34b-hf")
478
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-34b-hf", torch_dtype="auto", device_map="auto")
479
+ elif vlm_name == "llava-next-72b-hf (Preload)":
480
+ vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-next-72b-hf")
481
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-next-72b-hf", torch_dtype="auto", device_map="auto")
482
+ elif vlm_type == "qwen2-vl":
483
+ if vlm_processor != "" and vlm_model != "":
484
+ vlm_model.to(device)
485
+ return vlm_model_dropdown
486
+ else:
487
+ if os.path.exists(vlm_local_path):
488
+ vlm_processor = Qwen2VLProcessor.from_pretrained(vlm_local_path)
489
+ vlm_model = Qwen2VLForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto")
490
+ else:
491
+ if vlm_name == "qwen2-vl-2b-instruct (Preload)":
492
+ vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
493
+ vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", torch_dtype="auto", device_map="auto")
494
+ elif vlm_name == "qwen2-vl-7b-instruct (Preload)":
495
+ vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
496
+ vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto")
497
+ elif vlm_name == "qwen2-vl-72b-instruct (Preload)":
498
+ vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-72B-Instruct")
499
+ vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-72B-Instruct", torch_dtype="auto", device_map="auto")
500
+ elif vlm_type == "openai":
501
+ pass
502
+ return "success"
503
+
504
+
505
+ def update_base_model(base_model_name):
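+ # swap the diffusion base model: free the old pipeline, then reuse a preloaded pipeline or rebuild StableDiffusionBrushNetPipeline from a local checkpoint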
506
+ global pipe
507
+ ## We recommend using preloaded models; otherwise the first load will spend a long time downloading weights. You can adjust this in base_model_template.py.
508
+ if pipe is not None:
509
+ del pipe
510
+ torch.cuda.empty_cache()
511
+ base_model_path, pipe = base_models_template[base_model_name]
512
+ if pipe != "":
513
+ pipe.to(device)
514
+ else:
515
+ if os.path.exists(base_model_path):
516
+ pipe = StableDiffusionBrushNetPipeline.from_pretrained(
517
+ base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
518
+ )
519
+ # pipe.enable_xformers_memory_efficient_attention()
520
+ pipe.enable_model_cpu_offload()
521
+ else:
522
+ raise gr.Error(f"The base model {base_model_name} does not exist")
523
+ return "success"
524
+
525
+
526
+ def submit_GPT4o_KEY(GPT4o_KEY):
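+ # validate the submitted key with a minimal chat completion; note the client is pointed at the DeepSeek endpoint rather than OpenAI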
527
+ global vlm_model, vlm_processor
528
+ if vlm_model is not None:
529
+ del vlm_model
530
+ torch.cuda.empty_cache()
531
+ try:
532
+ vlm_model = OpenAI(api_key=GPT4o_KEY, base_url="https://api.deepseek.com")
533
+ vlm_processor = ""
534
+ response = vlm_model.chat.completions.create(
535
+ model="deepseek-chat",
536
+ messages=[
537
+ {"role": "system", "content": "You are a helpful assistant."},
538
+ {"role": "user", "content": "Hello."}
539
+ ]
540
+ )
541
+ response_str = response.choices[0].message.content
542
+
543
+ return "Success, " + response_str, "GPT4-o (Highly Recommended)"
544
+ except Exception as e:
545
+ return "Invalid GPT4o API Key", "GPT4-o (Highly Recommended)"
546
+
547
+
548
+
549
+ def process(input_image,
550
+ original_image,
551
+ original_mask,
552
+ prompt,
553
+ negative_prompt,
554
+ control_strength,
555
+ seed,
556
+ randomize_seed,
557
+ guidance_scale,
558
+ num_inference_steps,
559
+ num_samples,
560
+ blending,
561
+ category,
562
+ target_prompt,
563
+ resize_default,
564
+ aspect_ratio_name,
565
+ invert_mask_state):
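+ # end-to-end edit: resolve the mask (user-drawn alpha layer, or VLM + GroundingDINO/SAM when none is drawn), derive the target prompt, then run the BrushEdit inpainting pipeline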
566
+ if original_image is None:
567
+ if input_image is None:
568
+ raise gr.Error('Please upload the input image')
569
+ else:
570
+ print("input_image的键:", input_image.keys()) # 打印字典键
571
+ image_pil = input_image["background"].convert("RGB")
572
+ original_image = np.array(image_pil)
573
+ if prompt is None or prompt == "":
574
+ if target_prompt is None or target_prompt == "":
575
+ raise gr.Error("Please input your instructions, e.g., remove the xxx")
576
+
577
+ alpha_mask = input_image["layers"][0].split()[3]
578
+ input_mask = np.asarray(alpha_mask)
579
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
580
+ if output_w == "" or output_h == "":
581
+ output_h, output_w = original_image.shape[:2]
582
+
583
+ if resize_default:
584
+ short_side = min(output_w, output_h)
585
+ scale_ratio = 640 / short_side
586
+ output_w = int(output_w * scale_ratio)
587
+ output_h = int(output_h * scale_ratio)
588
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
589
+ original_image = np.array(original_image)
590
+ if input_mask is not None:
591
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
592
+ input_mask = np.array(input_mask)
593
+ if original_mask is not None:
594
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
595
+ original_mask = np.array(original_mask)
596
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
597
+ else:
598
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
599
+ pass
600
+ else:
601
+ if resize_default:
602
+ short_side = min(output_w, output_h)
603
+ scale_ratio = 640 / short_side
604
+ output_w = int(output_w * scale_ratio)
605
+ output_h = int(output_h * scale_ratio)
606
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
607
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
608
+ original_image = np.array(original_image)
609
+ if input_mask is not None:
610
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
611
+ input_mask = np.array(input_mask)
612
+ if original_mask is not None:
613
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
614
+ original_mask = np.array(original_mask)
615
+
616
+ if invert_mask_state:
617
+ original_mask = original_mask
618
+ else:
619
+ if input_mask.max() == 0:
620
+ original_mask = original_mask
621
+ else:
622
+ original_mask = input_mask
623
+
624
+
625
+ ## inpainting directly if target_prompt is not None
626
+ if category is not None:
627
+ pass
628
+ elif target_prompt is not None and len(target_prompt) >= 1 and original_mask is not None:
629
+ pass
630
+ else:
631
+ try:
632
+ category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device)
633
+ except Exception as e:
634
+ raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
635
+
636
+
637
+ if original_mask is not None:
638
+ original_mask = np.clip(original_mask, 0, 255).astype(np.uint8)
639
+ else:
640
+ try:
641
+ object_wait_for_edit = vlm_response_object_wait_for_edit(
642
+ vlm_processor,
643
+ vlm_model,
644
+ original_image,
645
+ category,
646
+ prompt,
647
+ device)
648
+
649
+ original_mask = vlm_response_mask(vlm_processor,
650
+ vlm_model,
651
+ category,
652
+ original_image,
653
+ prompt,
654
+ object_wait_for_edit,
655
+ sam,
656
+ sam_predictor,
657
+ sam_automask_generator,
658
+ groundingdino_model,
659
+ device).astype(np.uint8)
660
+ except Exception as e:
661
+ raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
662
+
663
+ if original_mask.ndim == 2:
664
+ original_mask = original_mask[:,:,None]
665
+
666
+
667
+ if target_prompt is not None and len(target_prompt) >= 1:
668
+ prompt_after_apply_instruction = target_prompt
669
+
670
+ else:
671
+ try:
672
+ prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(
673
+ vlm_processor,
674
+ vlm_model,
675
+ original_image,
676
+ prompt,
677
+ device)
678
+ except Exception as e:
679
+ raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
680
+
681
+ generator = torch.Generator(device).manual_seed(random.randint(0, 2147483647) if randomize_seed else seed)
682
+
683
+
684
+ with torch.autocast(device):
685
+ image, mask_image, mask_np, init_image_np = BrushEdit_Pipeline(pipe,
686
+ prompt_after_apply_instruction,
687
+ original_mask,
688
+ original_image,
689
+ generator,
690
+ num_inference_steps,
691
+ guidance_scale,
692
+ control_strength,
693
+ negative_prompt,
694
+ num_samples,
695
+ blending)
696
+ original_image = np.array(init_image_np)
697
+ masked_image = original_image * (1 - (mask_np>0))
698
+ masked_image = masked_image.astype(np.uint8)
699
+ masked_image = Image.fromarray(masked_image)
700
+ # Save the images (optional)
701
+ # import uuid
702
+ # uuid = str(uuid.uuid4())
703
+ # image[0].save(f"outputs/image_edit_{uuid}_0.png")
704
+ # image[1].save(f"outputs/image_edit_{uuid}_1.png")
705
+ # image[2].save(f"outputs/image_edit_{uuid}_2.png")
706
+ # image[3].save(f"outputs/image_edit_{uuid}_3.png")
707
+ # mask_image.save(f"outputs/mask_{uuid}.png")
708
+ # masked_image.save(f"outputs/masked_image_{uuid}.png")
709
+ gr.Info(f"Target Prompt: {prompt_after_apply_instruction}", duration=20)
710
+ return image, [mask_image], [masked_image], prompt, '', False
711
+
712
+
713
+ def process_mask(input_image,
714
+ original_image,
715
+ prompt,
716
+ resize_default,
717
+ aspect_ratio_name):
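+ # mask-only step: if the user drew no mask, ask the VLM for the object to edit and segment it with GroundingDINO/SAM; otherwise keep the drawn mask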
718
+ if original_image is None:
719
+ raise gr.Error('Please upload the input image')
720
+ if prompt is None:
721
+ raise gr.Error("Please input your instructions, e.g., remove the xxx")
722
+
723
+ ## load mask
724
+ alpha_mask = input_image["layers"][0].split()[3]
725
+ input_mask = np.array(alpha_mask)
726
+
727
+ # load example image
728
+ if isinstance(original_image, str):
729
+ original_image = input_image["background"]
730
+
731
+ if input_mask.max() == 0:
732
+ category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device)
733
+
734
+ object_wait_for_edit = vlm_response_object_wait_for_edit(vlm_processor,
735
+ vlm_model,
736
+ original_image,
737
+ category,
738
+ prompt,
739
+ device)
740
+ # original mask: h,w,1 [0, 255]
741
+ original_mask = vlm_response_mask(
742
+ vlm_processor,
743
+ vlm_model,
744
+ category,
745
+ original_image,
746
+ prompt,
747
+ object_wait_for_edit,
748
+ sam,
749
+ sam_predictor,
750
+ sam_automask_generator,
751
+ groundingdino_model,
752
+ device).astype(np.uint8)
753
+ else:
754
+ original_mask = input_mask.astype(np.uint8)
755
+ category = None
756
+
757
+ ## resize mask if needed
758
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
759
+ if output_w == "" or output_h == "":
760
+ output_h, output_w = original_image.shape[:2]
761
+ if resize_default:
762
+ short_side = min(output_w, output_h)
763
+ scale_ratio = 640 / short_side
764
+ output_w = int(output_w * scale_ratio)
765
+ output_h = int(output_h * scale_ratio)
766
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
767
+ original_image = np.array(original_image)
768
+ if input_mask is not None:
769
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
770
+ input_mask = np.array(input_mask)
771
+ if original_mask is not None:
772
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
773
+ original_mask = np.array(original_mask)
774
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
775
+ else:
776
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
777
+ pass
778
+ else:
779
+ if resize_default:
780
+ short_side = min(output_w, output_h)
781
+ scale_ratio = 640 / short_side
782
+ output_w = int(output_w * scale_ratio)
783
+ output_h = int(output_h * scale_ratio)
784
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
785
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
786
+ original_image = np.array(original_image)
787
+ if input_mask is not None:
788
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
789
+ input_mask = np.array(input_mask)
790
+ if original_mask is not None:
791
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
792
+ original_mask = np.array(original_mask)
793
+
794
+
795
+ if original_mask.ndim == 2:
796
+ original_mask = original_mask[:,:,None]
797
+
798
+ mask_image = Image.fromarray(original_mask.squeeze().astype(np.uint8)).convert("RGB")
799
+
800
+ masked_image = original_image * (1 - (original_mask>0))
801
+ masked_image = masked_image.astype(np.uint8)
802
+ masked_image = Image.fromarray(masked_image)
803
+
804
+ return [masked_image], [mask_image], original_mask.astype(np.uint8), category
805
+
806
+
807
+ def process_random_mask(input_image,
808
+ original_image,
809
+ original_mask,
810
+ resize_default,
811
+ aspect_ratio_name,
812
+ ):
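+ # replace the current mask with a randomly chosen bounding box or bounding ellipse around it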
813
+
814
+ alpha_mask = input_image["layers"][0].split()[3]
815
+ input_mask = np.asarray(alpha_mask)
816
+
817
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
818
+ if output_w == "" or output_h == "":
819
+ output_h, output_w = original_image.shape[:2]
820
+ if resize_default:
821
+ short_side = min(output_w, output_h)
822
+ scale_ratio = 640 / short_side
823
+ output_w = int(output_w * scale_ratio)
824
+ output_h = int(output_h * scale_ratio)
825
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
826
+ original_image = np.array(original_image)
827
+ if input_mask is not None:
828
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
829
+ input_mask = np.array(input_mask)
830
+ if original_mask is not None:
831
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
832
+ original_mask = np.array(original_mask)
833
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
834
+ else:
835
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
836
+ pass
837
+ else:
838
+ if resize_default:
839
+ short_side = min(output_w, output_h)
840
+ scale_ratio = 640 / short_side
841
+ output_w = int(output_w * scale_ratio)
842
+ output_h = int(output_h * scale_ratio)
843
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
844
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
845
+ original_image = np.array(original_image)
846
+ if input_mask is not None:
847
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
848
+ input_mask = np.array(input_mask)
849
+ if original_mask is not None:
850
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
851
+ original_mask = np.array(original_mask)
852
+
853
+
854
+ if input_mask.max() == 0:
855
+ original_mask = original_mask
856
+ else:
857
+ original_mask = input_mask
858
+
859
+ if original_mask is None:
860
+ raise gr.Error('Please generate mask first')
861
+
862
+ if original_mask.ndim == 2:
863
+ original_mask = original_mask[:,:,None]
864
+
865
+ dilation_type = np.random.choice(['bounding_box', 'bounding_ellipse'])
866
+ random_mask = random_mask_func(original_mask, dilation_type).squeeze()
867
+
868
+ mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
869
+
870
+ masked_image = original_image * (1 - (random_mask[:,:,None]>0))
871
+ masked_image = masked_image.astype(original_image.dtype)
872
+ masked_image = Image.fromarray(masked_image)
873
+
874
+
875
+ return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
876
+
877
+
878
+ def process_dilation_mask(input_image,
879
+ original_image,
880
+ original_mask,
881
+ resize_default,
882
+ aspect_ratio_name,
883
+ dilation_size=20):
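+ # grow the current mask with a square dilation of the chosen size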
884
+
885
+ alpha_mask = input_image["layers"][0].split()[3]
886
+ input_mask = np.asarray(alpha_mask)
887
+
888
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
889
+ if output_w == "" or output_h == "":
890
+ output_h, output_w = original_image.shape[:2]
891
+ if resize_default:
892
+ short_side = min(output_w, output_h)
893
+ scale_ratio = 640 / short_side
894
+ output_w = int(output_w * scale_ratio)
895
+ output_h = int(output_h * scale_ratio)
896
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
897
+ original_image = np.array(original_image)
898
+ if input_mask is not None:
899
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
900
+ input_mask = np.array(input_mask)
901
+ if original_mask is not None:
902
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
903
+ original_mask = np.array(original_mask)
904
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
905
+ else:
906
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
907
+ pass
908
+ else:
909
+ if resize_default:
910
+ short_side = min(output_w, output_h)
911
+ scale_ratio = 640 / short_side
912
+ output_w = int(output_w * scale_ratio)
913
+ output_h = int(output_h * scale_ratio)
914
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
915
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
916
+ original_image = np.array(original_image)
917
+ if input_mask is not None:
918
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
919
+ input_mask = np.array(input_mask)
920
+ if original_mask is not None:
921
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
922
+ original_mask = np.array(original_mask)
923
+
924
+ if input_mask.max() == 0:
925
+ original_mask = original_mask
926
+ else:
927
+ original_mask = input_mask
928
+
929
+ if original_mask is None:
930
+ raise gr.Error('Please generate mask first')
931
+
932
+ if original_mask.ndim == 2:
933
+ original_mask = original_mask[:,:,None]
934
+
935
+ dilation_type = np.random.choice(['square_dilation'])
936
+ random_mask = random_mask_func(original_mask, dilation_type, dilation_size).squeeze()
937
+
938
+ mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
939
+
940
+ masked_image = original_image * (1 - (random_mask[:,:,None]>0))
941
+ masked_image = masked_image.astype(original_image.dtype)
942
+ masked_image = Image.fromarray(masked_image)
943
+
944
+ return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
945
+
946
+
947
+ def process_erosion_mask(input_image,
948
+ original_image,
949
+ original_mask,
950
+ resize_default,
951
+ aspect_ratio_name,
952
+ dilation_size=20):
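+ # shrink the current mask with a square erosion of the chosen size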
953
+ alpha_mask = input_image["layers"][0].split()[3]
954
+ input_mask = np.asarray(alpha_mask)
955
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
956
+ if output_w == "" or output_h == "":
957
+ output_h, output_w = original_image.shape[:2]
958
+ if resize_default:
959
+ short_side = min(output_w, output_h)
960
+ scale_ratio = 640 / short_side
961
+ output_w = int(output_w * scale_ratio)
962
+ output_h = int(output_h * scale_ratio)
963
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
964
+ original_image = np.array(original_image)
965
+ if input_mask is not None:
966
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
967
+ input_mask = np.array(input_mask)
968
+ if original_mask is not None:
969
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
970
+ original_mask = np.array(original_mask)
971
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
972
+ else:
973
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
974
+ pass
975
+ else:
976
+ if resize_default:
977
+ short_side = min(output_w, output_h)
978
+ scale_ratio = 640 / short_side
979
+ output_w = int(output_w * scale_ratio)
980
+ output_h = int(output_h * scale_ratio)
981
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
982
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
983
+ original_image = np.array(original_image)
984
+ if input_mask is not None:
985
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
986
+ input_mask = np.array(input_mask)
987
+ if original_mask is not None:
988
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
989
+ original_mask = np.array(original_mask)
990
+
991
+ if input_mask.max() == 0:
992
+ original_mask = original_mask
993
+ else:
994
+ original_mask = input_mask
995
+
996
+ if original_mask is None:
997
+ raise gr.Error('Please generate mask first')
998
+
999
+ if original_mask.ndim == 2:
1000
+ original_mask = original_mask[:,:,None]
1001
+
1002
+ dilation_type = np.random.choice(['square_erosion'])
1003
+ random_mask = random_mask_func(original_mask, dilation_type, dilation_size).squeeze()
1004
+
1005
+ mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
1006
+
1007
+ masked_image = original_image * (1 - (random_mask[:,:,None]>0))
1008
+ masked_image = masked_image.astype(original_image.dtype)
1009
+ masked_image = Image.fromarray(masked_image)
1010
+
1011
+
1012
+ return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
1013
+
1014
+
1015
+ def move_mask_left(input_image,
1016
+ original_image,
1017
+ original_mask,
1018
+ moving_pixels,
1019
+ resize_default,
1020
+ aspect_ratio_name):
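+ # the four move_mask_* handlers share the same flow: resize to the target aspect ratio, pick the drawn or stored mask, shift it with move_mask_func, and return the refreshed previews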
1021
+
1022
+ alpha_mask = input_image["layers"][0].split()[3]
1023
+ input_mask = np.asarray(alpha_mask)
1024
+
1025
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
1026
+ if output_w == "" or output_h == "":
1027
+ output_h, output_w = original_image.shape[:2]
1028
+ if resize_default:
1029
+ short_side = min(output_w, output_h)
1030
+ scale_ratio = 640 / short_side
1031
+ output_w = int(output_w * scale_ratio)
1032
+ output_h = int(output_h * scale_ratio)
1033
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1034
+ original_image = np.array(original_image)
1035
+ if input_mask is not None:
1036
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1037
+ input_mask = np.array(input_mask)
1038
+ if original_mask is not None:
1039
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1040
+ original_mask = np.array(original_mask)
1041
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1042
+ else:
1043
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1044
+ pass
1045
+ else:
1046
+ if resize_default:
1047
+ short_side = min(output_w, output_h)
1048
+ scale_ratio = 640 / short_side
1049
+ output_w = int(output_w * scale_ratio)
1050
+ output_h = int(output_h * scale_ratio)
1051
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1052
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1053
+ original_image = np.array(original_image)
1054
+ if input_mask is not None:
1055
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1056
+ input_mask = np.array(input_mask)
1057
+ if original_mask is not None:
1058
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1059
+ original_mask = np.array(original_mask)
1060
+
1061
+ if input_mask.max() == 0:
1062
+ original_mask = original_mask
1063
+ else:
1064
+ original_mask = input_mask
1065
+
1066
+ if original_mask is None:
1067
+ raise gr.Error('Please generate mask first')
1068
+
1069
+ if original_mask.ndim == 2:
1070
+ original_mask = original_mask[:,:,None]
1071
+
1072
+ moved_mask = move_mask_func(original_mask, 'left', int(moving_pixels)).squeeze()
1073
+ mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
1074
+
1075
+ masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
1076
+ masked_image = masked_image.astype(original_image.dtype)
1077
+ masked_image = Image.fromarray(masked_image)
1078
+
1079
+ if moved_mask.max() <= 1:
1080
+ moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
1081
+ original_mask = moved_mask
1082
+ return [masked_image], [mask_image], original_mask.astype(np.uint8)
1083
+
1084
+
1085
+ def move_mask_right(input_image,
1086
+ original_image,
1087
+ original_mask,
1088
+ moving_pixels,
1089
+ resize_default,
1090
+ aspect_ratio_name):
1091
+ alpha_mask = input_image["layers"][0].split()[3]
1092
+ input_mask = np.asarray(alpha_mask)
1093
+
1094
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
1095
+ if output_w == "" or output_h == "":
1096
+ output_h, output_w = original_image.shape[:2]
1097
+ if resize_default:
1098
+ short_side = min(output_w, output_h)
1099
+ scale_ratio = 640 / short_side
1100
+ output_w = int(output_w * scale_ratio)
1101
+ output_h = int(output_h * scale_ratio)
1102
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1103
+ original_image = np.array(original_image)
1104
+ if input_mask is not None:
1105
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1106
+ input_mask = np.array(input_mask)
1107
+ if original_mask is not None:
1108
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1109
+ original_mask = np.array(original_mask)
1110
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1111
+ else:
1112
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1113
+ pass
1114
+ else:
1115
+ if resize_default:
1116
+ short_side = min(output_w, output_h)
1117
+ scale_ratio = 640 / short_side
1118
+ output_w = int(output_w * scale_ratio)
1119
+ output_h = int(output_h * scale_ratio)
1120
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1121
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1122
+ original_image = np.array(original_image)
1123
+ if input_mask is not None:
1124
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1125
+ input_mask = np.array(input_mask)
1126
+ if original_mask is not None:
1127
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1128
+ original_mask = np.array(original_mask)
1129
+
1130
+ if input_mask.max() == 0:
1131
+ original_mask = original_mask
1132
+ else:
1133
+ original_mask = input_mask
1134
+
1135
+ if original_mask is None:
1136
+ raise gr.Error('Please generate mask first')
1137
+
1138
+ if original_mask.ndim == 2:
1139
+ original_mask = original_mask[:,:,None]
1140
+
1141
+ moved_mask = move_mask_func(original_mask, 'right', int(moving_pixels)).squeeze()
1142
+
1143
+ mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
1144
+
1145
+ masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
1146
+ masked_image = masked_image.astype(original_image.dtype)
1147
+ masked_image = Image.fromarray(masked_image)
1148
+
1149
+
1150
+ if moved_mask.max() <= 1:
1151
+ moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
1152
+ original_mask = moved_mask
1153
+
1154
+ return [masked_image], [mask_image], original_mask.astype(np.uint8)
1155
+
1156
+
1157
+ def move_mask_up(input_image,
1158
+ original_image,
1159
+ original_mask,
1160
+ moving_pixels,
1161
+ resize_default,
1162
+ aspect_ratio_name):
1163
+ alpha_mask = input_image["layers"][0].split()[3]
1164
+ input_mask = np.asarray(alpha_mask)
1165
+
1166
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
1167
+ if output_w == "" or output_h == "":
1168
+ output_h, output_w = original_image.shape[:2]
1169
+ if resize_default:
1170
+ short_side = min(output_w, output_h)
1171
+ scale_ratio = 640 / short_side
1172
+ output_w = int(output_w * scale_ratio)
1173
+ output_h = int(output_h * scale_ratio)
1174
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1175
+ original_image = np.array(original_image)
1176
+ if input_mask is not None:
1177
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1178
+ input_mask = np.array(input_mask)
1179
+ if original_mask is not None:
1180
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1181
+ original_mask = np.array(original_mask)
1182
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1183
+ else:
1184
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1185
+ pass
1186
+ else:
1187
+ if resize_default:
1188
+ short_side = min(output_w, output_h)
1189
+ scale_ratio = 640 / short_side
1190
+ output_w = int(output_w * scale_ratio)
1191
+ output_h = int(output_h * scale_ratio)
1192
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1193
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1194
+ original_image = np.array(original_image)
1195
+ if input_mask is not None:
1196
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1197
+ input_mask = np.array(input_mask)
1198
+ if original_mask is not None:
1199
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1200
+ original_mask = np.array(original_mask)
1201
+
1202
+ if input_mask.max() == 0:
1203
+ original_mask = original_mask
1204
+ else:
1205
+ original_mask = input_mask
1206
+
1207
+ if original_mask is None:
1208
+ raise gr.Error('Please generate mask first')
1209
+
1210
+ if original_mask.ndim == 2:
1211
+ original_mask = original_mask[:,:,None]
1212
+
1213
+ moved_mask = move_mask_func(original_mask, 'up', int(moving_pixels)).squeeze()
1214
+ mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
1215
+
1216
+ masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
1217
+ masked_image = masked_image.astype(original_image.dtype)
1218
+ masked_image = Image.fromarray(masked_image)
1219
+
1220
+ if moved_mask.max() <= 1:
1221
+ moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
1222
+ original_mask = moved_mask
1223
+
1224
+ return [masked_image], [mask_image], original_mask.astype(np.uint8)
1225
+
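Editor's note: the move_mask_* helpers above (and the other mask utilities in this file) share the same "Short edge resize to 640px" rule. A minimal sketch of that arithmetic, with illustrative 1600x800 dimensions:

# Sketch of the short-edge resize rule (example values only; not part of the committed diff).
output_w, output_h = 1600, 800
scale_ratio = 640 / min(output_w, output_h)   # scale factor that maps the short side to 640
output_w = int(output_w * scale_ratio)        # 1280
output_h = int(output_h * scale_ratio)        # 640 (the short side)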
1226
+
1227
+ def move_mask_down(input_image,
1228
+ original_image,
1229
+ original_mask,
1230
+ moving_pixels,
1231
+ resize_default,
1232
+ aspect_ratio_name):
1233
+ alpha_mask = input_image["layers"][0].split()[3]
1234
+ input_mask = np.asarray(alpha_mask)
1235
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
1236
+ if output_w == "" or output_h == "":
1237
+ output_h, output_w = original_image.shape[:2]
1238
+ if resize_default:
1239
+ short_side = min(output_w, output_h)
1240
+ scale_ratio = 640 / short_side
1241
+ output_w = int(output_w * scale_ratio)
1242
+ output_h = int(output_h * scale_ratio)
1243
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1244
+ original_image = np.array(original_image)
1245
+ if input_mask is not None:
1246
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1247
+ input_mask = np.array(input_mask)
1248
+ if original_mask is not None:
1249
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1250
+ original_mask = np.array(original_mask)
1251
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1252
+ else:
1253
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1254
+ pass
1255
+ else:
1256
+ if resize_default:
1257
+ short_side = min(output_w, output_h)
1258
+ scale_ratio = 640 / short_side
1259
+ output_w = int(output_w * scale_ratio)
1260
+ output_h = int(output_h * scale_ratio)
1261
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1262
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1263
+ original_image = np.array(original_image)
1264
+ if input_mask is not None:
1265
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1266
+ input_mask = np.array(input_mask)
1267
+ if original_mask is not None:
1268
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1269
+ original_mask = np.array(original_mask)
1270
+
1271
+ if input_mask.max() == 0:
1272
+ original_mask = original_mask
1273
+ else:
1274
+ original_mask = input_mask
1275
+
1276
+ if original_mask is None:
1277
+ raise gr.Error('Please generate mask first')
1278
+
1279
+ if original_mask.ndim == 2:
1280
+ original_mask = original_mask[:,:,None]
1281
+
1282
+ moved_mask = move_mask_func(original_mask, 'down', int(moving_pixels)).squeeze()
1283
+ mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
1284
+
1285
+ masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
1286
+ masked_image = masked_image.astype(original_image.dtype)
1287
+ masked_image = Image.fromarray(masked_image)
1288
+
1289
+ if moved_mask.max() <= 1:
1290
+ moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
1291
+ original_mask = moved_mask
1292
+
1293
+ return [masked_image], [mask_image], original_mask.astype(np.uint8)
1294
+
1295
+
1296
+ def invert_mask(input_image,
1297
+ original_image,
1298
+ original_mask,
1299
+ ):
1300
+ alpha_mask = input_image["layers"][0].split()[3]
1301
+ input_mask = np.asarray(alpha_mask)
1302
+ if input_mask.max() == 0:
1303
+ original_mask = 1 - (original_mask>0).astype(np.uint8)
1304
+ else:
1305
+ original_mask = 1 - (input_mask>0).astype(np.uint8)
1306
+
1307
+ if original_mask is None:
1308
+ raise gr.Error('Please generate mask first')
1309
+
1310
+ original_mask = original_mask.squeeze()
1311
+ mask_image = Image.fromarray(original_mask*255).convert("RGB")
1312
+
1313
+ if original_mask.ndim == 2:
1314
+ original_mask = original_mask[:,:,None]
1315
+
1316
+ if original_mask.max() <= 1:
1317
+ original_mask = (original_mask * 255).astype(np.uint8)
1318
+
1319
+ masked_image = original_image * (1 - (original_mask>0))
1320
+ masked_image = masked_image.astype(original_image.dtype)
1321
+ masked_image = Image.fromarray(masked_image)
1322
+
1323
+ return [masked_image], [mask_image], original_mask, True
1324
+
1325
+
1326
+ def init_img(base,
1327
+ init_type,
1328
+ prompt,
1329
+ aspect_ratio,
1330
+ example_change_times
1331
+ ):
1332
+ image_pil = base["background"].convert("RGB")
1333
+ original_image = np.array(image_pil)
1334
+ if max(original_image.shape[0], original_image.shape[1]) * 1.0 / min(original_image.shape[0], original_image.shape[1])>2.0:
1335
+ raise gr.Error('image aspect ratio cannot be larger than 2.0')
1336
+ if init_type in MASK_IMAGE_PATH.keys() and example_change_times < 2:
1337
+ mask_gallery = [Image.open(MASK_IMAGE_PATH[init_type]).convert("L")]
1338
+ masked_gallery = [Image.open(MASKED_IMAGE_PATH[init_type]).convert("RGB")]
1339
+ result_gallery = [Image.open(OUTPUT_IMAGE_PATH[init_type]).convert("RGB")]
1340
+ width, height = image_pil.size
1341
+ image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
1342
+ height_new, width_new = image_processor.get_default_height_width(image_pil, height, width)
1343
+ image_pil = image_pil.resize((width_new, height_new))
1344
+ mask_gallery[0] = mask_gallery[0].resize((width_new, height_new))
1345
+ masked_gallery[0] = masked_gallery[0].resize((width_new, height_new))
1346
+ result_gallery[0] = result_gallery[0].resize((width_new, height_new))
1347
+ original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1
1348
+ return base, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, "", "", "Custom resolution", False, False, example_change_times
1349
+ else:
1350
+ if aspect_ratio not in ASPECT_RATIO_LABELS:
1351
+ aspect_ratio = "Custom resolution"
1352
+ return base, original_image, None, "", None, None, None, "", "", aspect_ratio, True, False, 0
1353
+
1354
+
1355
+ def reset_func(input_image,
1356
+ original_image,
1357
+ original_mask,
1358
+ prompt,
1359
+ target_prompt,
1360
+ ):
1361
+ input_image = None
1362
+ original_image = None
1363
+ original_mask = None
1364
+ prompt = ''
1365
+ mask_gallery = []
1366
+ masked_gallery = []
1367
+ result_gallery = []
1368
+ target_prompt = ''
1369
+ if torch.cuda.is_available():
1370
+ torch.cuda.empty_cache()
1371
+ return input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, True, False
1372
+
1373
+
1374
+ def update_example(example_type,
1375
+ prompt,
1376
+ example_change_times):
1377
+ input_image = INPUT_IMAGE_PATH[example_type]
1378
+ image_pil = Image.open(input_image).convert("RGB")
1379
+ mask_gallery = [Image.open(MASK_IMAGE_PATH[example_type]).convert("L")]
1380
+ masked_gallery = [Image.open(MASKED_IMAGE_PATH[example_type]).convert("RGB")]
1381
+ result_gallery = [Image.open(OUTPUT_IMAGE_PATH[example_type]).convert("RGB")]
1382
+ width, height = image_pil.size
1383
+ image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
1384
+ height_new, width_new = image_processor.get_default_height_width(image_pil, height, width)
1385
+ image_pil = image_pil.resize((width_new, height_new))
1386
+ mask_gallery[0] = mask_gallery[0].resize((width_new, height_new))
1387
+ masked_gallery[0] = masked_gallery[0].resize((width_new, height_new))
1388
+ result_gallery[0] = result_gallery[0].resize((width_new, height_new))
1389
+
1390
+ original_image = np.array(image_pil)
1391
+ original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1
1392
+ aspect_ratio = "Custom resolution"
1393
+ example_change_times += 1
1394
+ return input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, "", False, example_change_times
1395
+
1396
+
1397
+ def generate_target_prompt(input_image,
1398
+ original_image,
1399
+ prompt):
1400
+ # load example image
1401
+ if isinstance(original_image, str):
1402
+ original_image = input_image
1403
+
1404
+ prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(
1405
+ vlm_processor,
1406
+ vlm_model,
1407
+ original_image,
1408
+ prompt,
1409
+ device)
1410
+ return prompt_after_apply_instruction
1411
+
1412
+
1413
+ # Additional event handler functions
1414
+ def generate_blip_description(input_image):
1415
+ if input_image is None:
1416
+ return "", "Input image cannot be None"
1417
+ from app.utils.utils import generate_caption
1418
+ blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
1419
+ blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large", torch_dtype=torch.float16).to(device)
1420
+ try:
1421
+ image_pil = input_image["background"].convert("RGB")
1422
+ except KeyError:
1423
+ return "", "Input image missing 'background' key"
1424
+ except AttributeError as e:
1425
+ return "", f"Invalid image object: {str(e)}"
1426
+ try:
1427
+ description = generate_caption(blip_processor, blip_model, image_pil, device)
1428
+ return description, description # update both the description state and the display component
1429
+ except Exception as e:
1430
+ return "", f"Caption generation failed: {str(e)}"
1431
+
1432
+
1433
+ def verify_deepseek_api():
1434
+ try:
1435
+ initialize_llm_model()
1436
+ response = llm_model.chat.completions.create(
1437
+ model="deepseek-chat",
1438
+ messages=[
1439
+ {"role": "system", "content": "You are a helpful assistant."},
1440
+ {"role": "user", "content": "Hello."}
1441
+ ]
1442
+ )
1443
+ return True
1444
+ except Exception as e:
1445
+ return "Invalid GPT4o API Key", "GPT4-o (Highly Recommended)"
1446
+
1447
+
1448
+ def llm_response_prompt_after_apply_instruction(image_caption, editing_prompt):
1449
+ try:
1450
+ initialize_llm_model()
1451
+ messages = create_apply_editing_messages_deepseek(image_caption, editing_prompt)
1452
+ response = llm_model.chat.completions.create(
1453
+ model="deepseek-chat",
1454
+ messages=messages
1455
+ )
1456
+ response_str = response.choices[0].message.content
1457
+ return response_str
1458
+ except Exception as e:
1459
+ raise gr.Error(f"未预期错误: {str(e)},请检查控制台日志获取详细信息")
1460
+
1461
+
1462
+ def enhance_description(prompt, blip_description):
1463
+ try:
1464
+ initialize_llm_model()
1465
+
1466
+ if not prompt or not blip_description:
1467
+ print("Empty input detected")
1468
+ return "", ""
1469
+
1470
+ print(f"Enhancing with prompt: {prompt}")
1471
+ description = llm_response_prompt_after_apply_instruction(blip_description, prompt)
1472
+ return description, description
1473
+
1474
+ except Exception as e:
1475
+ print(f"Enhancement failed: {str(e)}")
1476
+ return "Error occurred", "Error occurred"
1477
+
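Editor's note: generate_blip_description and enhance_description together form the caption-enhancement path: BLIP captions the uploaded image, then the DeepSeek chat model merges that caption with the user's instruction. A minimal standalone sketch of that flow follows; the plain system prompt stands in for create_apply_editing_messages_deepseek, and the image path and API key are placeholders.

# Standalone sketch of the BLIP -> DeepSeek enhancement flow (illustrative; prompt wording is assumed).
from PIL import Image
from openai import OpenAI
from transformers import BlipProcessor, BlipForConditionalGeneration

def enhance_caption_sketch(image_path: str, instruction: str, api_key: str) -> str:
    processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
    model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
    # 1) caption the image with BLIP
    inputs = processor(Image.open(image_path).convert("RGB"), return_tensors="pt")
    caption = processor.decode(model.generate(**inputs, max_new_tokens=50)[0], skip_special_tokens=True)
    # 2) ask DeepSeek to merge the caption with the editing instruction
    client = OpenAI(api_key=api_key, base_url="https://api.deepseek.com/")
    response = client.chat.completions.create(
        model="deepseek-chat",
        messages=[
            {"role": "system", "content": "Merge the image caption and the editing instruction into one target description."},
            {"role": "user", "content": f"Caption: {caption}\nInstruction: {instruction}"},
        ],
    )
    return response.choices[0].message.content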
1478
+
1479
+ block = gr.Blocks(
1480
+ theme=gr.themes.Soft(
1481
+ radius_size=gr.themes.sizes.radius_none,
1482
+ text_size=gr.themes.sizes.text_md
1483
+ )
1484
+ )
1485
+ with block as demo:
1486
+ with gr.Row():
1487
+ with gr.Column():
1488
+ gr.HTML(head)
1489
+
1490
+ gr.Markdown(descriptions)
1491
+
1492
+ with gr.Accordion(label="🧭 Instructions:", open=True, elem_id="accordion"):
1493
+ with gr.Row(equal_height=True):
1494
+ gr.Markdown(instructions)
1495
+
1496
+ original_image = gr.State(value=None)
1497
+ original_mask = gr.State(value=None)
1498
+ category = gr.State(value=None)
1499
+ status = gr.State(value=None)
1500
+ invert_mask_state = gr.State(value=False)
1501
+ example_change_times = gr.State(value=0)
1502
+ deepseek_verified = gr.State(value=False)
1503
+ blip_description = gr.State(value="")
1504
+ deepseek_description = gr.State(value="")
1505
+
1506
+
1507
+ with gr.Row():
1508
+ with gr.Column():
1509
+ with gr.Row():
1510
+ input_image = gr.ImageEditor(
1511
+ label="Input Image",
1512
+ type="pil",
1513
+ brush=gr.Brush(colors=["#FFFFFF"], default_size = 30, color_mode="fixed"),
1514
+ layers = False,
1515
+ interactive=True,
1516
+ # height=1024,
1517
+ height=512,
1518
+ sources=["upload"],
1519
+ placeholder="Please click here or the icon below to upload the image.",
1520
+ )
1521
+
1522
+ prompt = gr.Textbox(label="⌨️ Instruction", placeholder="Please input your instruction.", value="",lines=1)
1523
+ run_button = gr.Button("💫 Run")
1524
+
1525
+ vlm_model_dropdown = gr.Dropdown(label="VLM model", choices=VLM_MODEL_NAMES, value=DEFAULT_VLM_MODEL_NAME, interactive=True)
1526
+ with gr.Group():
1527
+ with gr.Row():
1528
+ # GPT4o_KEY = gr.Textbox(label="GPT4o API Key", placeholder="Please input your GPT4o API Key when use GPT4o VLM (highly recommended).", value="", lines=1)
1529
+ # GPT4o_KEY = gr.Textbox(type="password", value="sk-d145b963a92649a88843caeb741e8bbc")
1530
+ GPT4o_KEY = gr.Textbox(label="GPT4o API Key", value="sk-d145b963a92649a88843caeb741e8bbc", lines=1)
1531
+ GPT4o_KEY_submit = gr.Button("Submit and Verify")
1532
+
1533
+
1534
+ aspect_ratio = gr.Dropdown(label="Output aspect ratio", choices=ASPECT_RATIO_LABELS, value=DEFAULT_ASPECT_RATIO)
1535
+ resize_default = gr.Checkbox(label="Short edge resize to 640px", value=True)
1536
+
1537
+ with gr.Row():
1538
+ mask_button = gr.Button("Generate Mask")
1539
+ random_mask_button = gr.Button("Square/Circle Mask ")
1540
+
1541
+
1542
+ with gr.Row():
1543
+ generate_target_prompt_button = gr.Button("Generate Target Prompt")
1544
+
1545
+ target_prompt = gr.Text(
1546
+ label="Input Target Prompt",
1547
+ max_lines=5,
1548
+ placeholder="VLM-generated target prompt, you can first generate if and then modify it (optional)",
1549
+ value='',
1550
+ lines=2
1551
+ )
1552
+
1553
+ with gr.Accordion("Advanced Options", open=False, elem_id="accordion1"):
1554
+ base_model_dropdown = gr.Dropdown(label="Base model", choices=BASE_MODELS, value=DEFAULT_BASE_MODEL, interactive=True)
1555
+ negative_prompt = gr.Text(
1556
+ label="Negative Prompt",
1557
+ max_lines=5,
1558
+ placeholder="Please input your negative prompt",
1559
+ value='ugly, low quality',lines=1
1560
+ )
1561
+
1562
+ control_strength = gr.Slider(
1563
+ label="Control Strength: ", show_label=True, minimum=0, maximum=1.1, value=1, step=0.01
1564
+ )
1565
+ with gr.Group():
1566
+ seed = gr.Slider(
1567
+ label="Seed: ", minimum=0, maximum=2147483647, step=1, value=648464818
1568
+ )
1569
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
1570
+
1571
+ blending = gr.Checkbox(label="Blending mode", value=True)
1572
+
1573
+
1574
+ num_samples = gr.Slider(
1575
+ label="Num samples", minimum=0, maximum=4, step=1, value=4
1576
+ )
1577
+
1578
+ with gr.Group():
1579
+ with gr.Row():
1580
+ guidance_scale = gr.Slider(
1581
+ label="Guidance scale",
1582
+ minimum=1,
1583
+ maximum=12,
1584
+ step=0.1,
1585
+ value=7.5,
1586
+ )
1587
+ num_inference_steps = gr.Slider(
1588
+ label="Number of inference steps",
1589
+ minimum=1,
1590
+ maximum=50,
1591
+ step=1,
1592
+ value=50,
1593
+ )
1594
+
1595
+
1596
+ with gr.Column():
1597
+ with gr.Group(visible=True):
1598
+ # BLIP-generated description
1599
+ blip_output = gr.Textbox(label="BLIP生成描述", placeholder="自动生成的图像基础描述...", interactive=False, lines=3)
1600
+ # DeepSeek API verification
1601
+ with gr.Row():
1602
+ deepseek_key = gr.Textbox(label="DeepSeek API Key", value="sk-d145b963a92649a88843caeb741e8bbc", placeholder="Enter your DeepSeek API key to enable enhancement", lines=1)
1603
+ verify_deepseek = gr.Button("Submit and Verify")
1604
+ # Consolidated description
1605
+ deepseek_output = gr.Textbox(label="Consolidated description", placeholder="Enhanced description generated by DeepSeek...", interactive=True, lines=3)
1606
+
1607
+ with gr.Row():
1608
+ with gr.Tab(elem_classes="feedback", label="Masked Image"):
1609
+ masked_gallery = gr.Gallery(label='Masked Image', show_label=True, elem_id="gallery", preview=True, height=360)
1610
+ with gr.Tab(elem_classes="feedback", label="Mask"):
1611
+ mask_gallery = gr.Gallery(label='Mask', show_label=True, elem_id="gallery", preview=True, height=360)
1612
+
1613
+ invert_mask_button = gr.Button("Invert Mask")
1614
+ dilation_size = gr.Slider(
1615
+ label="Dilation size: ", minimum=0, maximum=50, step=1, value=20
1616
+ )
1617
+ with gr.Row():
1618
+ dilation_mask_button = gr.Button("Dilation Generated Mask")
1619
+ erosion_mask_button = gr.Button("Erosion Generated Mask")
1620
+
1621
+ moving_pixels = gr.Slider(
1622
+ label="Moving pixels:", show_label=True, minimum=0, maximum=50, value=4, step=1
1623
+ )
1624
+ with gr.Row():
1625
+ move_left_button = gr.Button("Move Left")
1626
+ move_right_button = gr.Button("Move Right")
1627
+ with gr.Row():
1628
+ move_up_button = gr.Button("Move Up")
1629
+ move_down_button = gr.Button("Move Down")
1630
+
1631
+ with gr.Tab(elem_classes="feedback", label="Output"):
1632
+ result_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", preview=True, height=400)
1633
+
1634
+ # target_prompt_output = gr.Text(label="Output Target Prompt", value="", lines=1, interactive=False)
1635
+
1636
+ reset_button = gr.Button("Reset")
1637
+
1638
+ init_type = gr.Textbox(label="Init Name", value="", visible=False)
1639
+ example_type = gr.Textbox(label="Example Name", value="", visible=False)
1640
+
1641
+
1642
+
1643
+ with gr.Row():
1644
+ example = gr.Examples(
1645
+ label="Quick Example",
1646
+ examples=EXAMPLES,
1647
+ inputs=[input_image, prompt, seed, init_type, example_type, blending, resize_default, vlm_model_dropdown],
1648
+ examples_per_page=10,
1649
+ cache_examples=False,
1650
+ )
1651
+
1652
+
1653
+ with gr.Accordion(label="🎬 Feature Details:", open=True, elem_id="accordion"):
1654
+ with gr.Row(equal_height=True):
1655
+ gr.Markdown(tips)
1656
+
1657
+ with gr.Row():
1658
+ gr.Markdown(citation)
1659
+
1660
+ ## gr.examples can not be used to update the gr.Gallery, so we need to use the following two functions to update the gr.Gallery.
1661
+ ## And we need to solve the conflict between the upload and change example functions.
1662
+ input_image.upload(
1663
+ init_img,
1664
+ [input_image, init_type, prompt, aspect_ratio, example_change_times],
1665
+ [input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, init_type, aspect_ratio, resize_default, invert_mask_state, example_change_times]
1666
+ )
1667
+ example_type.change(fn=update_example, inputs=[example_type, prompt, example_change_times], outputs=[input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, target_prompt, invert_mask_state, example_change_times])
1668
+
1669
+ ## vlm and base model dropdown
1670
+ vlm_model_dropdown.change(fn=update_vlm_model, inputs=[vlm_model_dropdown], outputs=[status])
1671
+ base_model_dropdown.change(fn=update_base_model, inputs=[base_model_dropdown], outputs=[status])
1672
+
1673
+
1674
+ GPT4o_KEY_submit.click(fn=submit_GPT4o_KEY, inputs=[GPT4o_KEY], outputs=[GPT4o_KEY, vlm_model_dropdown])
1675
+ invert_mask_button.click(fn=invert_mask, inputs=[input_image, original_image, original_mask], outputs=[masked_gallery, mask_gallery, original_mask, invert_mask_state])
1676
+
1677
+
1678
+ ips=[input_image,
1679
+ original_image,
1680
+ original_mask,
1681
+ prompt,
1682
+ negative_prompt,
1683
+ control_strength,
1684
+ seed,
1685
+ randomize_seed,
1686
+ guidance_scale,
1687
+ num_inference_steps,
1688
+ num_samples,
1689
+ blending,
1690
+ category,
1691
+ target_prompt,
1692
+ resize_default,
1693
+ aspect_ratio,
1694
+ invert_mask_state]
1695
+
1696
+ ## run brushedit
1697
+ run_button.click(fn=process, inputs=ips, outputs=[result_gallery, mask_gallery, masked_gallery, prompt, target_prompt, invert_mask_state])
1698
+
1699
+ ## mask func
1700
+ mask_button.click(fn=process_mask, inputs=[input_image, original_image, prompt, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask, category])
1701
+ random_mask_button.click(fn=process_random_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
1702
+ dilation_mask_button.click(fn=process_dilation_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[ masked_gallery, mask_gallery, original_mask])
1703
+ erosion_mask_button.click(fn=process_erosion_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[ masked_gallery, mask_gallery, original_mask])
1704
+
1705
+ ## move mask func
1706
+ move_left_button.click(fn=move_mask_left, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
1707
+ move_right_button.click(fn=move_mask_right, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
1708
+ move_up_button.click(fn=move_mask_up, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
1709
+ move_down_button.click(fn=move_mask_down, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
1710
+
1711
+ ## prompt func
1712
+ generate_target_prompt_button.click(fn=generate_target_prompt, inputs=[input_image, original_image, prompt], outputs=[target_prompt])
1713
+
1714
+ ## reset func
1715
+ reset_button.click(fn=reset_func, inputs=[input_image, original_image, original_mask, prompt, target_prompt], outputs=[input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, resize_default, invert_mask_state])
1716
+
1717
+
1718
+ # Bind event handlers
1719
+ input_image.upload(fn=generate_blip_description, inputs=[input_image], outputs=[blip_description, blip_output])
1720
+ verify_deepseek.click(fn=verify_deepseek_api, outputs=[deepseek_verified]).success(fn=enhance_description, inputs=[prompt, blip_description], outputs=[deepseek_description, deepseek_output])
1721
+ # Automatically trigger enhancement when the BLIP description changes (requires successful verification)
1722
+ blip_description.change(fn=enhance_description, inputs=[prompt, blip_description], outputs=[deepseek_description, deepseek_output], preprocess=False)
1723
+
1724
+ # if have a localhost access error, try to use the following code
1725
+ demo.launch(server_name="0.0.0.0", server_port=12345, share=True)
brushedit_app_without_clip.py ADDED
@@ -0,0 +1,1758 @@
1
+ ##!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import os, random, sys
4
+ import numpy as np
5
+ import requests
6
+ import torch
7
+
8
+
9
+ import gradio as gr
10
+
11
+ from PIL import Image
12
+
13
+
14
+ from huggingface_hub import hf_hub_download, snapshot_download
15
+ from scipy.ndimage import binary_dilation, binary_erosion
16
+ from transformers import (LlavaNextProcessor, LlavaNextForConditionalGeneration,
17
+ Qwen2VLForConditionalGeneration, Qwen2VLProcessor)
18
+
19
+ from segment_anything import SamPredictor, build_sam, SamAutomaticMaskGenerator
20
+ from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler
21
+ from diffusers.image_processor import VaeImageProcessor
22
+
23
+
24
+ from app.src.vlm_pipeline import (
25
+ vlm_response_editing_type,
26
+ vlm_response_object_wait_for_edit,
27
+ vlm_response_mask,
28
+ vlm_response_prompt_after_apply_instruction
29
+ )
30
+ from app.src.brushedit_all_in_one_pipeline import BrushEdit_Pipeline
31
+ from app.utils.utils import load_grounding_dino_model
32
+
33
+ from app.src.vlm_template import vlms_template
34
+ from app.src.base_model_template import base_models_template
35
+ from app.src.aspect_ratio_template import aspect_ratios
36
+
37
+ from openai import OpenAI
38
+ base_openai_url = "https://api.deepseek.com/"
39
+ base_api_key = "sk-d145b963a92649a88843caeb741e8bbc"
40
+
41
+
42
+ from transformers import BlipProcessor, BlipForConditionalGeneration
43
+
44
+ from app.deepseek.instructions import (
45
+ create_apply_editing_messages_deepseek,
46
+ create_decomposed_query_messages_deepseek
47
+ )
48
+
49
+
50
+ #### Description ####
51
+ logo = r"""
52
+ <center><img src='./assets/logo_brushedit.png' alt='BrushEdit logo' style="width:80px; margin-bottom:10px"></center>
53
+ """
54
+ head = r"""
55
+ <div style="text-align: center;">
56
+ <h1> Zero-Shot Composed Image Retrieval Based on Diffusion Model Priors and Large Language Models </h1>
57
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
58
+ <a href=''><img src='https://img.shields.io/badge/Project_Page-BrushEdit-green' alt='Project Page'></a>
59
+ <a href=''><img src='https://img.shields.io/badge/Paper-Arxiv-blue'></a>
60
+ <a href=''><img src='https://img.shields.io/badge/Code-Github-orange'></a>
61
+
62
+ </div>
63
+ </br>
64
+ </div>
65
+ """
66
+ descriptions = r"""
67
+ Demo for ZS-CIR"""
68
+
69
+ instructions = r"""
70
+ Demo for ZS-CIR"""
71
+
72
+ tips = r"""
73
+ Demo for ZS-CIR
74
+
75
+ """
76
+
77
+
78
+
79
+ citation = r"""
80
+ Demo for ZS-CIR"""
81
+
82
+ # - - - - - examples - - - - - #
83
+ EXAMPLES = [
84
+
85
+ [
86
+ Image.open("./assets/frog/frog.jpeg").convert("RGBA"),
87
+ "add a magic hat on frog head.",
88
+ 642087011,
89
+ "frog",
90
+ "frog",
91
+ True,
92
+ False,
93
+ "GPT4-o (Highly Recommended)"
94
+ ],
95
+ [
96
+ Image.open("./assets/chinese_girl/chinese_girl.png").convert("RGBA"),
97
+ "replace the background to ancient China.",
98
+ 648464818,
99
+ "chinese_girl",
100
+ "chinese_girl",
101
+ True,
102
+ False,
103
+ "GPT4-o (Highly Recommended)"
104
+ ],
105
+ [
106
+ Image.open("./assets/angel_christmas/angel_christmas.png").convert("RGBA"),
107
+ "remove the deer.",
108
+ 648464818,
109
+ "angel_christmas",
110
+ "angel_christmas",
111
+ False,
112
+ False,
113
+ "GPT4-o (Highly Recommended)"
114
+ ],
115
+ [
116
+ Image.open("./assets/sunflower_girl/sunflower_girl.png").convert("RGBA"),
117
+ "add a wreath on head.",
118
+ 648464818,
119
+ "sunflower_girl",
120
+ "sunflower_girl",
121
+ True,
122
+ False,
123
+ "GPT4-o (Highly Recommended)"
124
+ ],
125
+ [
126
+ Image.open("./assets/girl_on_sun/girl_on_sun.png").convert("RGBA"),
127
+ "add a butterfly fairy.",
128
+ 648464818,
129
+ "girl_on_sun",
130
+ "girl_on_sun",
131
+ True,
132
+ False,
133
+ "GPT4-o (Highly Recommended)"
134
+ ],
135
+ [
136
+ Image.open("./assets/spider_man_rm/spider_man.png").convert("RGBA"),
137
+ "remove the christmas hat.",
138
+ 642087011,
139
+ "spider_man_rm",
140
+ "spider_man_rm",
141
+ False,
142
+ False,
143
+ "GPT4-o (Highly Recommended)"
144
+ ],
145
+ [
146
+ Image.open("./assets/anime_flower/anime_flower.png").convert("RGBA"),
147
+ "remove the flower.",
148
+ 642087011,
149
+ "anime_flower",
150
+ "anime_flower",
151
+ False,
152
+ False,
153
+ "GPT4-o (Highly Recommended)"
154
+ ],
155
+ [
156
+ Image.open("./assets/chenduling/chengduling.jpg").convert("RGBA"),
157
+ "replace the clothes to a delicated floral skirt.",
158
+ 648464818,
159
+ "chenduling",
160
+ "chenduling",
161
+ True,
162
+ False,
163
+ "GPT4-o (Highly Recommended)"
164
+ ],
165
+ [
166
+ Image.open("./assets/hedgehog_rp_bg/hedgehog.png").convert("RGBA"),
167
+ "make the hedgehog in Italy.",
168
+ 648464818,
169
+ "hedgehog_rp_bg",
170
+ "hedgehog_rp_bg",
171
+ True,
172
+ False,
173
+ "GPT4-o (Highly Recommended)"
174
+ ],
175
+
176
+ ]
177
+
178
+ INPUT_IMAGE_PATH = {
179
+ "frog": "./assets/frog/frog.jpeg",
180
+ "chinese_girl": "./assets/chinese_girl/chinese_girl.png",
181
+ "angel_christmas": "./assets/angel_christmas/angel_christmas.png",
182
+ "sunflower_girl": "./assets/sunflower_girl/sunflower_girl.png",
183
+ "girl_on_sun": "./assets/girl_on_sun/girl_on_sun.png",
184
+ "spider_man_rm": "./assets/spider_man_rm/spider_man.png",
185
+ "anime_flower": "./assets/anime_flower/anime_flower.png",
186
+ "chenduling": "./assets/chenduling/chengduling.jpg",
187
+ "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/hedgehog.png",
188
+ }
189
+ MASK_IMAGE_PATH = {
190
+ "frog": "./assets/frog/mask_f7b350de-6f2c-49e3-b535-995c486d78e7.png",
191
+ "chinese_girl": "./assets/chinese_girl/mask_54759648-0989-48e0-bc82-f20e28b5ec29.png",
192
+ "angel_christmas": "./assets/angel_christmas/mask_f15d9b45-c978-4e3d-9f5f-251e308560c3.png",
193
+ "sunflower_girl": "./assets/sunflower_girl/mask_99cc50b4-7dc4-4de5-8748-ec10772f0317.png",
194
+ "girl_on_sun": "./assets/girl_on_sun/mask_264eac8b-8b65-479c-9755-020a60880c37.png",
195
+ "spider_man_rm": "./assets/spider_man_rm/mask_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png",
196
+ "anime_flower": "./assets/anime_flower/mask_37553172-9b38-4727-bf2e-37d7e2b93461.png",
197
+ "chenduling": "./assets/chenduling/mask_68e3ff6f-da07-4b37-91df-13d6eed7b997.png",
198
+ "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/mask_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png",
199
+ }
200
+ MASKED_IMAGE_PATH = {
201
+ "frog": "./assets/frog/masked_image_f7b350de-6f2c-49e3-b535-995c486d78e7.png",
202
+ "chinese_girl": "./assets/chinese_girl/masked_image_54759648-0989-48e0-bc82-f20e28b5ec29.png",
203
+ "angel_christmas": "./assets/angel_christmas/masked_image_f15d9b45-c978-4e3d-9f5f-251e308560c3.png",
204
+ "sunflower_girl": "./assets/sunflower_girl/masked_image_99cc50b4-7dc4-4de5-8748-ec10772f0317.png",
205
+ "girl_on_sun": "./assets/girl_on_sun/masked_image_264eac8b-8b65-479c-9755-020a60880c37.png",
206
+ "spider_man_rm": "./assets/spider_man_rm/masked_image_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png",
207
+ "anime_flower": "./assets/anime_flower/masked_image_37553172-9b38-4727-bf2e-37d7e2b93461.png",
208
+ "chenduling": "./assets/chenduling/masked_image_68e3ff6f-da07-4b37-91df-13d6eed7b997.png",
209
+ "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/masked_image_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png",
210
+ }
211
+ OUTPUT_IMAGE_PATH = {
212
+ "frog": "./assets/frog/image_edit_f7b350de-6f2c-49e3-b535-995c486d78e7_1.png",
213
+ "chinese_girl": "./assets/chinese_girl/image_edit_54759648-0989-48e0-bc82-f20e28b5ec29_1.png",
214
+ "angel_christmas": "./assets/angel_christmas/image_edit_f15d9b45-c978-4e3d-9f5f-251e308560c3_0.png",
215
+ "sunflower_girl": "./assets/sunflower_girl/image_edit_99cc50b4-7dc4-4de5-8748-ec10772f0317_3.png",
216
+ "girl_on_sun": "./assets/girl_on_sun/image_edit_264eac8b-8b65-479c-9755-020a60880c37_0.png",
217
+ "spider_man_rm": "./assets/spider_man_rm/image_edit_a5d410e6-8e8d-432f-8144-defbc3e1eae9_0.png",
218
+ "anime_flower": "./assets/anime_flower/image_edit_37553172-9b38-4727-bf2e-37d7e2b93461_2.png",
219
+ "chenduling": "./assets/chenduling/image_edit_68e3ff6f-da07-4b37-91df-13d6eed7b997_0.png",
220
+ "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/image_edit_db7f8bf8-8349-46d3-b14e-43d67fbe25d3_3.png",
221
+ }
222
+
223
+ # os.environ['GRADIO_TEMP_DIR'] = 'gradio_temp_dir'
224
+ # os.makedirs('gradio_temp_dir', exist_ok=True)
225
+
226
+ VLM_MODEL_NAMES = list(vlms_template.keys())
227
+ DEFAULT_VLM_MODEL_NAME = "Qwen2-VL-7B-Instruct (Default)"
228
+
229
+
230
+ BASE_MODELS = list(base_models_template.keys())
231
+ DEFAULT_BASE_MODEL = "realisticVision (Default)"
232
+
233
+ ASPECT_RATIO_LABELS = list(aspect_ratios)
234
+ DEFAULT_ASPECT_RATIO = ASPECT_RATIO_LABELS[0]
235
+
236
+
237
+ ## init device
238
+ try:
239
+ if torch.cuda.is_available():
240
+ device = "cuda"
241
+ elif sys.platform == "darwin" and torch.backends.mps.is_available():
242
+ device = "mps"
243
+ else:
244
+ device = "cpu"
245
+ except:
246
+ device = "cpu"
247
+
248
+ # ## init torch dtype
249
+ # if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
250
+ # torch_dtype = torch.bfloat16
251
+ # else:
252
+ # torch_dtype = torch.float16
253
+
254
+ # if device == "mps":
255
+ # torch_dtype = torch.float16
256
+
257
+ torch_dtype = torch.float16
258
+
259
+
260
+
261
+ # download hf models
262
+ BrushEdit_path = "models/"
263
+ if not os.path.exists(BrushEdit_path):
264
+ BrushEdit_path = snapshot_download(
265
+ repo_id="TencentARC/BrushEdit",
266
+ local_dir=BrushEdit_path,
267
+ token=os.getenv("HF_TOKEN"),
268
+ )
269
+
270
+ ## init default VLM
271
+ vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[DEFAULT_VLM_MODEL_NAME]
272
+ if vlm_processor != "" and vlm_model != "":
273
+ vlm_model.to(device)
274
+ else:
275
+ raise gr.Error("Please Download default VLM model "+ DEFAULT_VLM_MODEL_NAME +" first.")
276
+
277
+ ## init default LLM
278
+ llm_model = OpenAI(api_key=base_api_key, base_url=base_openai_url)
279
+
280
+ ## init base model
281
+ base_model_path = os.path.join(BrushEdit_path, "base_model/realisticVisionV60B1_v51VAE")
282
+ brushnet_path = os.path.join(BrushEdit_path, "brushnetX")
283
+ sam_path = os.path.join(BrushEdit_path, "sam/sam_vit_h_4b8939.pth")
284
+ groundingdino_path = os.path.join(BrushEdit_path, "grounding_dino/groundingdino_swint_ogc.pth")
285
+
286
+
287
+ # input brushnetX ckpt path
288
+ brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch_dtype)
289
+ pipe = StableDiffusionBrushNetPipeline.from_pretrained(
290
+ base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
291
+ )
292
+ # speed up diffusion process with faster scheduler and memory optimization
293
+ pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
294
+ # remove following line if xformers is not installed or when using Torch 2.0.
295
+ # pipe.enable_xformers_memory_efficient_attention()
296
+ pipe.enable_model_cpu_offload()
297
+
298
+
299
+ ## init SAM
300
+ sam = build_sam(checkpoint=sam_path)
301
+ sam.to(device=device)
302
+ sam_predictor = SamPredictor(sam)
303
+ sam_automask_generator = SamAutomaticMaskGenerator(sam)
304
+
305
+ ## init groundingdino_model
306
+ config_file = 'app/utils/GroundingDINO_SwinT_OGC.py'
307
+ groundingdino_model = load_grounding_dino_model(config_file, groundingdino_path, device=device)
308
+
309
+ ## Ordinary function
310
+ def crop_and_resize(image: Image.Image,
311
+ target_width: int,
312
+ target_height: int) -> Image.Image:
313
+ """
314
+ Crops and resizes an image while preserving the aspect ratio.
315
+
316
+ Args:
317
+ image (Image.Image): Input PIL image to be cropped and resized.
318
+ target_width (int): Target width of the output image.
319
+ target_height (int): Target height of the output image.
320
+
321
+ Returns:
322
+ Image.Image: Cropped and resized image.
323
+ """
324
+ # Original dimensions
325
+ original_width, original_height = image.size
326
+ original_aspect = original_width / original_height
327
+ target_aspect = target_width / target_height
328
+
329
+ # Calculate crop box to maintain aspect ratio
330
+ if original_aspect > target_aspect:
331
+ # Crop horizontally
332
+ new_width = int(original_height * target_aspect)
333
+ new_height = original_height
334
+ left = (original_width - new_width) / 2
335
+ top = 0
336
+ right = left + new_width
337
+ bottom = original_height
338
+ else:
339
+ # Crop vertically
340
+ new_width = original_width
341
+ new_height = int(original_width / target_aspect)
342
+ left = 0
343
+ top = (original_height - new_height) / 2
344
+ right = original_width
345
+ bottom = top + new_height
346
+
347
+ # Crop and resize
348
+ cropped_image = image.crop((left, top, right, bottom))
349
+ resized_image = cropped_image.resize((target_width, target_height), Image.NEAREST)
350
+ return resized_image
351
+
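Editor's note: crop_and_resize center-crops to the target aspect ratio before resampling, so the output is never stretched. A quick usage sketch (the input size is illustrative):

# Illustrative use of crop_and_resize (not part of the committed diff).
from PIL import Image
src = Image.new("RGB", (1600, 900))                       # stand-in for a loaded photo
out = crop_and_resize(src, target_width=1024, target_height=768)
print(out.size)                                           # (1024, 768); the 16:9 input is center-cropped to 4:3 first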
352
+
353
+ ## Ordinary function
354
+ def resize(image: Image.Image,
355
+ target_width: int,
356
+ target_height: int) -> Image.Image:
357
+ """
358
+ Resizes an image to the target dimensions without preserving the aspect ratio.
359
+
360
+ Args:
361
+ image (Image.Image): Input PIL image to be cropped and resized.
362
+ target_width (int): Target width of the output image.
363
+ target_height (int): Target height of the output image.
364
+
365
+ Returns:
366
+ Image.Image: Resized image.
367
+ """
368
+ # Resize directly to the target size
369
+ resized_image = image.resize((target_width, target_height), Image.NEAREST)
370
+ return resized_image
371
+
372
+
373
+ def move_mask_func(mask, direction, units):
374
+ binary_mask = mask.squeeze()>0
375
+ rows, cols = binary_mask.shape
376
+ moved_mask = np.zeros_like(binary_mask, dtype=bool)
377
+
378
+ if direction == 'down':
379
+ # move down
380
+ moved_mask[max(0, units):, :] = binary_mask[:rows - units, :]
381
+
382
+ elif direction == 'up':
383
+ # move up
384
+ moved_mask[:rows - units, :] = binary_mask[units:, :]
385
+
386
+ elif direction == 'right':
387
+ # move left
388
+ moved_mask[:, max(0, units):] = binary_mask[:, :cols - units]
389
+
390
+ elif direction == 'left':
391
+ # move left
392
+ moved_mask[:, :cols - units] = binary_mask[:, units:]
393
+
394
+ return moved_mask
395
+
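Editor's note: move_mask_func translates the binary mask by plain array slicing and pads the vacated region with zeros. A tiny sanity sketch (blob position and shift amount are illustrative):

# Illustrative check of move_mask_func (not part of the committed diff).
import numpy as np
demo_mask = np.zeros((6, 6), dtype=np.uint8)
demo_mask[2:4, 1:3] = 255                                 # a 2x2 foreground blob
shifted = move_mask_func(demo_mask, 'right', 2)           # boolean array; blob now occupies columns 3:5
assert shifted[2:4, 3:5].all() and shifted.sum() == 4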
396
+
397
+ def random_mask_func(mask, dilation_type='square', dilation_size=20):
398
+ # Build an expanded mask according to the requested dilation type and size
399
+ binary_mask = mask.squeeze()>0
400
+
401
+ if dilation_type == 'square_dilation':
402
+ structure = np.ones((dilation_size, dilation_size), dtype=bool)
403
+ dilated_mask = binary_dilation(binary_mask, structure=structure)
404
+ elif dilation_type == 'square_erosion':
405
+ structure = np.ones((dilation_size, dilation_size), dtype=bool)
406
+ dilated_mask = binary_erosion(binary_mask, structure=structure)
407
+ elif dilation_type == 'bounding_box':
408
+ # find the most left top and left bottom point
409
+ rows, cols = np.where(binary_mask)
410
+ if len(rows) == 0 or len(cols) == 0:
411
+ return mask # return original mask if no valid points
412
+
413
+ min_row = np.min(rows)
414
+ max_row = np.max(rows)
415
+ min_col = np.min(cols)
416
+ max_col = np.max(cols)
417
+
418
+ # create a bounding box
419
+ dilated_mask = np.zeros_like(binary_mask, dtype=bool)
420
+ dilated_mask[min_row:max_row + 1, min_col:max_col + 1] = True
421
+
422
+ elif dilation_type == 'bounding_ellipse':
423
+ # find the most left top and left bottom point
424
+ rows, cols = np.where(binary_mask)
425
+ if len(rows) == 0 or len(cols) == 0:
426
+ return mask # return original mask if no valid points
427
+
428
+ min_row = np.min(rows)
429
+ max_row = np.max(rows)
430
+ min_col = np.min(cols)
431
+ max_col = np.max(cols)
432
+
433
+ # calculate the center and axis length of the ellipse
434
+ center = ((min_col + max_col) // 2, (min_row + max_row) // 2)
435
+ a = (max_col - min_col) // 2 # half long axis
436
+ b = (max_row - min_row) // 2 # half short axis
437
+
438
+ # create a bounding ellipse
439
+ y, x = np.ogrid[:mask.shape[0], :mask.shape[1]]
440
+ ellipse_mask = ((x - center[0])**2 / a**2 + (y - center[1])**2 / b**2) <= 1
441
+ dilated_mask = np.zeros_like(binary_mask, dtype=bool)
442
+ dilated_mask[ellipse_mask] = True
443
+ else:
444
+ raise ValueError("dilation_type must be 'square_dilation', 'square_erosion', 'bounding_box', or 'bounding_ellipse'")
445
+
446
+ # convert the boolean mask to a uint8 (h, w, 1) mask in [0, 255]
447
+ dilated_mask = np.uint8(dilated_mask[:,:,np.newaxis]) * 255
448
+ return dilated_mask
449
+
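Editor's note: despite its name, random_mask_func is deterministic; the randomness lives in its callers, which pick the dilation_type (process_random_mask below chooses between 'bounding_box' and 'bounding_ellipse'). A small sketch of the bounding-box mode with illustrative values:

# Illustrative use of random_mask_func (not part of the committed diff).
import numpy as np
blob = np.zeros((8, 8, 1), dtype=np.uint8)
blob[2:5, 3:6, 0] = 255                                            # hand-drawn foreground blob
box_mask = random_mask_func(blob, dilation_type='bounding_box')    # uint8 (8, 8, 1) mask in {0, 255}
assert box_mask[2:5, 3:6, 0].all() and box_mask.max() == 255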
450
+
451
+ ## Gradio component function
452
+ def update_vlm_model(vlm_name):
453
+ global vlm_model, vlm_processor
454
+ if vlm_model is not None:
455
+ del vlm_model
456
+ torch.cuda.empty_cache()
457
+
458
+ vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[vlm_name]
459
+
460
+ ## we recommend using preload models, otherwise it will take a long time to download the model. you can edit the code via vlm_template.py
461
+ if vlm_type == "llava-next":
462
+ if vlm_processor != "" and vlm_model != "":
463
+ vlm_model.to(device)
464
+ return vlm_model_dropdown
465
+ else:
466
+ if os.path.exists(vlm_local_path):
467
+ vlm_processor = LlavaNextProcessor.from_pretrained(vlm_local_path)
468
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto")
469
+ else:
470
+ if vlm_name == "llava-v1.6-mistral-7b-hf (Preload)":
471
+ vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
472
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", torch_dtype="auto", device_map="auto")
473
+ elif vlm_name == "llama3-llava-next-8b-hf (Preload)":
474
+ vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llama3-llava-next-8b-hf")
475
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llama3-llava-next-8b-hf", torch_dtype="auto", device_map="auto")
476
+ elif vlm_name == "llava-v1.6-vicuna-13b-hf (Preload)":
477
+ vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf")
478
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf", torch_dtype="auto", device_map="auto")
479
+ elif vlm_name == "llava-v1.6-34b-hf (Preload)":
480
+ vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-34b-hf")
481
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-34b-hf", torch_dtype="auto", device_map="auto")
482
+ elif vlm_name == "llava-next-72b-hf (Preload)":
483
+ vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-next-72b-hf")
484
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-next-72b-hf", torch_dtype="auto", device_map="auto")
485
+ elif vlm_type == "qwen2-vl":
486
+ if vlm_processor != "" and vlm_model != "":
487
+ vlm_model.to(device)
488
+ return vlm_model_dropdown
489
+ else:
490
+ if os.path.exists(vlm_local_path):
491
+ vlm_processor = Qwen2VLProcessor.from_pretrained(vlm_local_path)
492
+ vlm_model = Qwen2VLForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto")
493
+ else:
494
+ if vlm_name == "qwen2-vl-2b-instruct (Preload)":
495
+ vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
496
+ vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", torch_dtype="auto", device_map="auto")
497
+ elif vlm_name == "qwen2-vl-7b-instruct (Preload)":
498
+ vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
499
+ vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto")
500
+ elif vlm_name == "qwen2-vl-72b-instruct (Preload)":
501
+ vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-72B-Instruct")
502
+ vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-72B-Instruct", torch_dtype="auto", device_map="auto")
503
+ elif vlm_type == "openai":
504
+ pass
505
+ return "success"
506
+
507
+
508
+ def update_base_model(base_model_name):
509
+ global pipe
510
+ ## we recommend using preload models, otherwise it will take a long time to download the model. you can edit the code via base_model_template.py
511
+ if pipe is not None:
512
+ del pipe
513
+ torch.cuda.empty_cache()
514
+ base_model_path, pipe = base_models_template[base_model_name]
515
+ if pipe != "":
516
+ pipe.to(device)
517
+ else:
518
+ if os.path.exists(base_model_path):
519
+ pipe = StableDiffusionBrushNetPipeline.from_pretrained(
520
+ base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
521
+ )
522
+ # pipe.enable_xformers_memory_efficient_attention()
523
+ pipe.enable_model_cpu_offload()
524
+ else:
525
+ raise gr.Error(f"The base model {base_model_name} does not exist")
526
+ return "success"
527
+
528
+
529
+ def process(input_image,
530
+ original_image,
531
+ original_mask,
532
+ prompt,
533
+ negative_prompt,
534
+ control_strength,
535
+ seed,
536
+ randomize_seed,
537
+ guidance_scale,
538
+ num_inference_steps,
539
+ num_samples,
540
+ blending,
541
+ category,
542
+ target_prompt,
543
+ resize_default,
544
+ aspect_ratio_name,
545
+ invert_mask_state):
546
+ if original_image is None:
547
+ if input_image is None:
548
+ raise gr.Error('Please upload the input image')
549
+ else:
550
+ print("input_image的键:", input_image.keys()) # 打印字典键
551
+ image_pil = input_image["background"].convert("RGB")
552
+ original_image = np.array(image_pil)
553
+ if prompt is None or prompt == "":
554
+ if target_prompt is None or target_prompt == "":
555
+ raise gr.Error("Please input your instructions, e.g., remove the xxx")
556
+
557
+ alpha_mask = input_image["layers"][0].split()[3]
558
+ input_mask = np.asarray(alpha_mask)
559
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
560
+ if output_w == "" or output_h == "":
561
+ output_h, output_w = original_image.shape[:2]
562
+
563
+ if resize_default:
564
+ short_side = min(output_w, output_h)
565
+ scale_ratio = 640 / short_side
566
+ output_w = int(output_w * scale_ratio)
567
+ output_h = int(output_h * scale_ratio)
568
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
569
+ original_image = np.array(original_image)
570
+ if input_mask is not None:
571
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
572
+ input_mask = np.array(input_mask)
573
+ if original_mask is not None:
574
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
575
+ original_mask = np.array(original_mask)
576
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
577
+ else:
578
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
579
+ pass
580
+ else:
581
+ if resize_default:
582
+ short_side = min(output_w, output_h)
583
+ scale_ratio = 640 / short_side
584
+ output_w = int(output_w * scale_ratio)
585
+ output_h = int(output_h * scale_ratio)
586
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
587
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
588
+ original_image = np.array(original_image)
589
+ if input_mask is not None:
590
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
591
+ input_mask = np.array(input_mask)
592
+ if original_mask is not None:
593
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
594
+ original_mask = np.array(original_mask)
595
+
596
+ if invert_mask_state:
597
+ original_mask = original_mask
598
+ else:
599
+ if input_mask.max() == 0:
600
+ original_mask = original_mask
601
+ else:
602
+ original_mask = input_mask
603
+
604
+
605
+ ## inpainting directly if target_prompt is not None
606
+ if category is not None:
607
+ pass
608
+ elif target_prompt is not None and len(target_prompt) >= 1 and original_mask is not None:
609
+ pass
610
+ else:
611
+ try:
612
+ category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device)
613
+ except Exception as e:
614
+ raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
615
+
616
+
617
+ if original_mask is not None:
618
+ original_mask = np.clip(original_mask, 0, 255).astype(np.uint8)
619
+ else:
620
+ try:
621
+ object_wait_for_edit = vlm_response_object_wait_for_edit(
622
+ vlm_processor,
623
+ vlm_model,
624
+ original_image,
625
+ category,
626
+ prompt,
627
+ device)
628
+
629
+ original_mask = vlm_response_mask(vlm_processor,
630
+ vlm_model,
631
+ category,
632
+ original_image,
633
+ prompt,
634
+ object_wait_for_edit,
635
+ sam,
636
+ sam_predictor,
637
+ sam_automask_generator,
638
+ groundingdino_model,
639
+ device).astype(np.uint8)
640
+ except Exception as e:
641
+ raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
642
+
643
+ if original_mask.ndim == 2:
644
+ original_mask = original_mask[:,:,None]
645
+
646
+
647
+ if target_prompt is not None and len(target_prompt) >= 1:
648
+ prompt_after_apply_instruction = target_prompt
649
+
650
+ else:
651
+ try:
652
+ prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(
653
+ vlm_processor,
654
+ vlm_model,
655
+ original_image,
656
+ prompt,
657
+ device)
658
+ except Exception as e:
659
+ raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
660
+
661
+ generator = torch.Generator(device).manual_seed(random.randint(0, 2147483647) if randomize_seed else seed)
662
+
663
+
664
+ with torch.autocast(device):
665
+ image, mask_image, mask_np, init_image_np = BrushEdit_Pipeline(pipe,
666
+ prompt_after_apply_instruction,
667
+ original_mask,
668
+ original_image,
669
+ generator,
670
+ num_inference_steps,
671
+ guidance_scale,
672
+ control_strength,
673
+ negative_prompt,
674
+ num_samples,
675
+ blending)
676
+ original_image = np.array(init_image_np)
677
+ masked_image = original_image * (1 - (mask_np>0))
678
+ masked_image = masked_image.astype(np.uint8)
679
+ masked_image = Image.fromarray(masked_image)
680
+ # Save the images (optional)
681
+ # import uuid
682
+ # uuid = str(uuid.uuid4())
683
+ # image[0].save(f"outputs/image_edit_{uuid}_0.png")
684
+ # image[1].save(f"outputs/image_edit_{uuid}_1.png")
685
+ # image[2].save(f"outputs/image_edit_{uuid}_2.png")
686
+ # image[3].save(f"outputs/image_edit_{uuid}_3.png")
687
+ # mask_image.save(f"outputs/mask_{uuid}.png")
688
+ # masked_image.save(f"outputs/masked_image_{uuid}.png")
689
+ gr.Info(f"Target Prompt: {prompt_after_apply_instruction}", duration=20)
690
+ return image, [mask_image], [masked_image], prompt, '', False
691
+
692
+
693
+ def process_mask(input_image,
694
+ original_image,
695
+ prompt,
696
+ resize_default,
697
+ aspect_ratio_name):
698
+ if original_image is None:
699
+ raise gr.Error('Please upload the input image')
700
+ if prompt is None:
701
+ raise gr.Error("Please input your instructions, e.g., remove the xxx")
702
+
703
+ ## load mask
704
+ alpha_mask = input_image["layers"][0].split()[3]
705
+ input_mask = np.array(alpha_mask)
706
+
707
+ # load example image
708
+ if isinstance(original_image, str):
709
+ original_image = input_image["background"]
710
+
711
+ if input_mask.max() == 0:
712
+ category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device)
713
+
714
+ object_wait_for_edit = vlm_response_object_wait_for_edit(vlm_processor,
715
+ vlm_model,
716
+ original_image,
717
+ category,
718
+ prompt,
719
+ device)
720
+ # original mask: h,w,1 [0, 255]
721
+ original_mask = vlm_response_mask(
722
+ vlm_processor,
723
+ vlm_model,
724
+ category,
725
+ original_image,
726
+ prompt,
727
+ object_wait_for_edit,
728
+ sam,
729
+ sam_predictor,
730
+ sam_automask_generator,
731
+ groundingdino_model,
732
+ device).astype(np.uint8)
733
+ else:
734
+ original_mask = input_mask.astype(np.uint8)
735
+ category = None
736
+
737
+ ## resize mask if needed
738
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
739
+ if output_w == "" or output_h == "":
740
+ output_h, output_w = original_image.shape[:2]
741
+ if resize_default:
742
+ short_side = min(output_w, output_h)
743
+ scale_ratio = 640 / short_side
744
+ output_w = int(output_w * scale_ratio)
745
+ output_h = int(output_h * scale_ratio)
746
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
747
+ original_image = np.array(original_image)
748
+ if input_mask is not None:
749
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
750
+ input_mask = np.array(input_mask)
751
+ if original_mask is not None:
752
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
753
+ original_mask = np.array(original_mask)
754
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
755
+ else:
756
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
757
+ pass
758
+ else:
759
+ if resize_default:
760
+ short_side = min(output_w, output_h)
761
+ scale_ratio = 640 / short_side
762
+ output_w = int(output_w * scale_ratio)
763
+ output_h = int(output_h * scale_ratio)
764
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
765
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
766
+ original_image = np.array(original_image)
767
+ if input_mask is not None:
768
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
769
+ input_mask = np.array(input_mask)
770
+ if original_mask is not None:
771
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
772
+ original_mask = np.array(original_mask)
773
+
774
+
775
+ if original_mask.ndim == 2:
776
+ original_mask = original_mask[:,:,None]
777
+
778
+ mask_image = Image.fromarray(original_mask.squeeze().astype(np.uint8)).convert("RGB")
779
+
780
+ masked_image = original_image * (1 - (original_mask>0))
781
+ masked_image = masked_image.astype(np.uint8)
782
+ masked_image = Image.fromarray(masked_image)
783
+
784
+ return [masked_image], [mask_image], original_mask.astype(np.uint8), category
785
+
786
+
787
+ def process_random_mask(input_image,
788
+ original_image,
789
+ original_mask,
790
+ resize_default,
791
+ aspect_ratio_name,
792
+ ):
793
+
794
+ alpha_mask = input_image["layers"][0].split()[3]
795
+ input_mask = np.asarray(alpha_mask)
796
+
797
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
798
+ if output_w == "" or output_h == "":
799
+ output_h, output_w = original_image.shape[:2]
800
+ if resize_default:
801
+ short_side = min(output_w, output_h)
802
+ scale_ratio = 640 / short_side
803
+ output_w = int(output_w * scale_ratio)
804
+ output_h = int(output_h * scale_ratio)
805
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
806
+ original_image = np.array(original_image)
807
+ if input_mask is not None:
808
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
809
+ input_mask = np.array(input_mask)
810
+ if original_mask is not None:
811
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
812
+ original_mask = np.array(original_mask)
813
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
814
+ else:
815
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
816
+ pass
817
+ else:
818
+ if resize_default:
819
+ short_side = min(output_w, output_h)
820
+ scale_ratio = 640 / short_side
821
+ output_w = int(output_w * scale_ratio)
822
+ output_h = int(output_h * scale_ratio)
823
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
824
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
825
+ original_image = np.array(original_image)
826
+ if input_mask is not None:
827
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
828
+ input_mask = np.array(input_mask)
829
+ if original_mask is not None:
830
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
831
+ original_mask = np.array(original_mask)
832
+
833
+
834
+ if input_mask.max() == 0:
835
+ original_mask = original_mask
836
+ else:
837
+ original_mask = input_mask
838
+
839
+ if original_mask is None:
840
+ raise gr.Error('Please generate mask first')
841
+
842
+ if original_mask.ndim == 2:
843
+ original_mask = original_mask[:,:,None]
844
+
845
+ dilation_type = np.random.choice(['bounding_box', 'bounding_ellipse'])
846
+ random_mask = random_mask_func(original_mask, dilation_type).squeeze()
847
+
848
+ mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
849
+
850
+ masked_image = original_image * (1 - (random_mask[:,:,None]>0))
851
+ masked_image = masked_image.astype(original_image.dtype)
852
+ masked_image = Image.fromarray(masked_image)
853
+
854
+
855
+ return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
856
+
857
+
858
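+ # Expand the current mask by dilation_size pixels (square dilation via random_mask_func).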
+ def process_dilation_mask(input_image,
859
+ original_image,
860
+ original_mask,
861
+ resize_default,
862
+ aspect_ratio_name,
863
+ dilation_size=20):
864
+
865
+ alpha_mask = input_image["layers"][0].split()[3]
866
+ input_mask = np.asarray(alpha_mask)
867
+
868
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
869
+ if output_w == "" or output_h == "":
870
+ output_h, output_w = original_image.shape[:2]
871
+ if resize_default:
872
+ short_side = min(output_w, output_h)
873
+ scale_ratio = 640 / short_side
874
+ output_w = int(output_w * scale_ratio)
875
+ output_h = int(output_h * scale_ratio)
876
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
877
+ original_image = np.array(original_image)
878
+ if input_mask is not None:
879
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
880
+ input_mask = np.array(input_mask)
881
+ if original_mask is not None:
882
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
883
+ original_mask = np.array(original_mask)
884
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
885
+ else:
886
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
887
+ pass
888
+ else:
889
+ if resize_default:
890
+ short_side = min(output_w, output_h)
891
+ scale_ratio = 640 / short_side
892
+ output_w = int(output_w * scale_ratio)
893
+ output_h = int(output_h * scale_ratio)
894
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
895
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
896
+ original_image = np.array(original_image)
897
+ if input_mask is not None:
898
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
899
+ input_mask = np.array(input_mask)
900
+ if original_mask is not None:
901
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
902
+ original_mask = np.array(original_mask)
903
+
904
+ if input_mask.max() == 0:
905
+ original_mask = original_mask
906
+ else:
907
+ original_mask = input_mask
908
+
909
+ if original_mask is None:
910
+ raise gr.Error('Please generate mask first')
911
+
912
+ if original_mask.ndim == 2:
913
+ original_mask = original_mask[:,:,None]
914
+
915
+ dilation_type = np.random.choice(['square_dilation'])
916
+ random_mask = random_mask_func(original_mask, dilation_type, dilation_size).squeeze()
917
+
918
+ mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
919
+
920
+ masked_image = original_image * (1 - (random_mask[:,:,None]>0))
921
+ masked_image = masked_image.astype(original_image.dtype)
922
+ masked_image = Image.fromarray(masked_image)
923
+
924
+ return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
925
+
926
+
927
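+ # Shrink the current mask by dilation_size pixels (square erosion via random_mask_func).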
+ def process_erosion_mask(input_image,
928
+ original_image,
929
+ original_mask,
930
+ resize_default,
931
+ aspect_ratio_name,
932
+ dilation_size=20):
933
+ alpha_mask = input_image["layers"][0].split()[3]
934
+ input_mask = np.asarray(alpha_mask)
935
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
936
+ if output_w == "" or output_h == "":
937
+ output_h, output_w = original_image.shape[:2]
938
+ if resize_default:
939
+ short_side = min(output_w, output_h)
940
+ scale_ratio = 640 / short_side
941
+ output_w = int(output_w * scale_ratio)
942
+ output_h = int(output_h * scale_ratio)
943
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
944
+ original_image = np.array(original_image)
945
+ if input_mask is not None:
946
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
947
+ input_mask = np.array(input_mask)
948
+ if original_mask is not None:
949
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
950
+ original_mask = np.array(original_mask)
951
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
952
+ else:
953
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
954
+ pass
955
+ else:
956
+ if resize_default:
957
+ short_side = min(output_w, output_h)
958
+ scale_ratio = 640 / short_side
959
+ output_w = int(output_w * scale_ratio)
960
+ output_h = int(output_h * scale_ratio)
961
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
962
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
963
+ original_image = np.array(original_image)
964
+ if input_mask is not None:
965
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
966
+ input_mask = np.array(input_mask)
967
+ if original_mask is not None:
968
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
969
+ original_mask = np.array(original_mask)
970
+
971
+ if input_mask.max() == 0:
972
+ original_mask = original_mask
973
+ else:
974
+ original_mask = input_mask
975
+
976
+ if original_mask is None:
977
+ raise gr.Error('Please generate mask first')
978
+
979
+ if original_mask.ndim == 2:
980
+ original_mask = original_mask[:,:,None]
981
+
982
+ dilation_type = np.random.choice(['square_erosion'])
983
+ random_mask = random_mask_func(original_mask, dilation_type, dilation_size).squeeze()
984
+
985
+ mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
986
+
987
+ masked_image = original_image * (1 - (random_mask[:,:,None]>0))
988
+ masked_image = masked_image.astype(original_image.dtype)
989
+ masked_image = Image.fromarray(masked_image)
990
+
991
+
992
+ return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
993
+
994
+
995
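+ # The four move_mask_* handlers below shift the current mask by moving_pixels in the
+ # corresponding direction (after the same resize preamble as above) and refresh the previews.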
+ def move_mask_left(input_image,
996
+ original_image,
997
+ original_mask,
998
+ moving_pixels,
999
+ resize_default,
1000
+ aspect_ratio_name):
1001
+
1002
+ alpha_mask = input_image["layers"][0].split()[3]
1003
+ input_mask = np.asarray(alpha_mask)
1004
+
1005
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
1006
+ if output_w == "" or output_h == "":
1007
+ output_h, output_w = original_image.shape[:2]
1008
+ if resize_default:
1009
+ short_side = min(output_w, output_h)
1010
+ scale_ratio = 640 / short_side
1011
+ output_w = int(output_w * scale_ratio)
1012
+ output_h = int(output_h * scale_ratio)
1013
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1014
+ original_image = np.array(original_image)
1015
+ if input_mask is not None:
1016
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1017
+ input_mask = np.array(input_mask)
1018
+ if original_mask is not None:
1019
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1020
+ original_mask = np.array(original_mask)
1021
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1022
+ else:
1023
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1024
+ pass
1025
+ else:
1026
+ if resize_default:
1027
+ short_side = min(output_w, output_h)
1028
+ scale_ratio = 640 / short_side
1029
+ output_w = int(output_w * scale_ratio)
1030
+ output_h = int(output_h * scale_ratio)
1031
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1032
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1033
+ original_image = np.array(original_image)
1034
+ if input_mask is not None:
1035
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1036
+ input_mask = np.array(input_mask)
1037
+ if original_mask is not None:
1038
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1039
+ original_mask = np.array(original_mask)
1040
+
1041
+ if input_mask.max() == 0:
1042
+ original_mask = original_mask
1043
+ else:
1044
+ original_mask = input_mask
1045
+
1046
+ if original_mask is None:
1047
+ raise gr.Error('Please generate mask first')
1048
+
1049
+ if original_mask.ndim == 2:
1050
+ original_mask = original_mask[:,:,None]
1051
+
1052
+ moved_mask = move_mask_func(original_mask, 'left', int(moving_pixels)).squeeze()
1053
+ mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
1054
+
1055
+ masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
1056
+ masked_image = masked_image.astype(original_image.dtype)
1057
+ masked_image = Image.fromarray(masked_image)
1058
+
1059
+ if moved_mask.max() <= 1:
1060
+ moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
1061
+ original_mask = moved_mask
1062
+ return [masked_image], [mask_image], original_mask.astype(np.uint8)
1063
+
1064
+
1065
+ def move_mask_right(input_image,
1066
+ original_image,
1067
+ original_mask,
1068
+ moving_pixels,
1069
+ resize_default,
1070
+ aspect_ratio_name):
1071
+ alpha_mask = input_image["layers"][0].split()[3]
1072
+ input_mask = np.asarray(alpha_mask)
1073
+
1074
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
1075
+ if output_w == "" or output_h == "":
1076
+ output_h, output_w = original_image.shape[:2]
1077
+ if resize_default:
1078
+ short_side = min(output_w, output_h)
1079
+ scale_ratio = 640 / short_side
1080
+ output_w = int(output_w * scale_ratio)
1081
+ output_h = int(output_h * scale_ratio)
1082
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1083
+ original_image = np.array(original_image)
1084
+ if input_mask is not None:
1085
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1086
+ input_mask = np.array(input_mask)
1087
+ if original_mask is not None:
1088
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1089
+ original_mask = np.array(original_mask)
1090
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1091
+ else:
1092
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1093
+ pass
1094
+ else:
1095
+ if resize_default:
1096
+ short_side = min(output_w, output_h)
1097
+ scale_ratio = 640 / short_side
1098
+ output_w = int(output_w * scale_ratio)
1099
+ output_h = int(output_h * scale_ratio)
1100
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1101
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1102
+ original_image = np.array(original_image)
1103
+ if input_mask is not None:
1104
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1105
+ input_mask = np.array(input_mask)
1106
+ if original_mask is not None:
1107
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1108
+ original_mask = np.array(original_mask)
1109
+
1110
+ if input_mask.max() == 0:
1111
+ original_mask = original_mask
1112
+ else:
1113
+ original_mask = input_mask
1114
+
1115
+ if original_mask is None:
1116
+ raise gr.Error('Please generate mask first')
1117
+
1118
+ if original_mask.ndim == 2:
1119
+ original_mask = original_mask[:,:,None]
1120
+
1121
+ moved_mask = move_mask_func(original_mask, 'right', int(moving_pixels)).squeeze()
1122
+
1123
+ mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
1124
+
1125
+ masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
1126
+ masked_image = masked_image.astype(original_image.dtype)
1127
+ masked_image = Image.fromarray(masked_image)
1128
+
1129
+
1130
+ if moved_mask.max() <= 1:
1131
+ moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
1132
+ original_mask = moved_mask
1133
+
1134
+ return [masked_image], [mask_image], original_mask.astype(np.uint8)
1135
+
1136
+
1137
+ def move_mask_up(input_image,
1138
+ original_image,
1139
+ original_mask,
1140
+ moving_pixels,
1141
+ resize_default,
1142
+ aspect_ratio_name):
1143
+ alpha_mask = input_image["layers"][0].split()[3]
1144
+ input_mask = np.asarray(alpha_mask)
1145
+
1146
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
1147
+ if output_w == "" or output_h == "":
1148
+ output_h, output_w = original_image.shape[:2]
1149
+ if resize_default:
1150
+ short_side = min(output_w, output_h)
1151
+ scale_ratio = 640 / short_side
1152
+ output_w = int(output_w * scale_ratio)
1153
+ output_h = int(output_h * scale_ratio)
1154
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1155
+ original_image = np.array(original_image)
1156
+ if input_mask is not None:
1157
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1158
+ input_mask = np.array(input_mask)
1159
+ if original_mask is not None:
1160
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1161
+ original_mask = np.array(original_mask)
1162
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1163
+ else:
1164
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1165
+ pass
1166
+ else:
1167
+ if resize_default:
1168
+ short_side = min(output_w, output_h)
1169
+ scale_ratio = 640 / short_side
1170
+ output_w = int(output_w * scale_ratio)
1171
+ output_h = int(output_h * scale_ratio)
1172
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1173
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1174
+ original_image = np.array(original_image)
1175
+ if input_mask is not None:
1176
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1177
+ input_mask = np.array(input_mask)
1178
+ if original_mask is not None:
1179
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1180
+ original_mask = np.array(original_mask)
1181
+
1182
+ if input_mask.max() == 0:
1183
+ original_mask = original_mask
1184
+ else:
1185
+ original_mask = input_mask
1186
+
1187
+ if original_mask is None:
1188
+ raise gr.Error('Please generate mask first')
1189
+
1190
+ if original_mask.ndim == 2:
1191
+ original_mask = original_mask[:,:,None]
1192
+
1193
+ moved_mask = move_mask_func(original_mask, 'up', int(moving_pixels)).squeeze()
1194
+ mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
1195
+
1196
+ masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
1197
+ masked_image = masked_image.astype(original_image.dtype)
1198
+ masked_image = Image.fromarray(masked_image)
1199
+
1200
+ if moved_mask.max() <= 1:
1201
+ moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
1202
+ original_mask = moved_mask
1203
+
1204
+ return [masked_image], [mask_image], original_mask.astype(np.uint8)
1205
+
1206
+
1207
+ def move_mask_down(input_image,
1208
+ original_image,
1209
+ original_mask,
1210
+ moving_pixels,
1211
+ resize_default,
1212
+ aspect_ratio_name):
1213
+ alpha_mask = input_image["layers"][0].split()[3]
1214
+ input_mask = np.asarray(alpha_mask)
1215
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
1216
+ if output_w == "" or output_h == "":
1217
+ output_h, output_w = original_image.shape[:2]
1218
+ if resize_default:
1219
+ short_side = min(output_w, output_h)
1220
+ scale_ratio = 640 / short_side
1221
+ output_w = int(output_w * scale_ratio)
1222
+ output_h = int(output_h * scale_ratio)
1223
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1224
+ original_image = np.array(original_image)
1225
+ if input_mask is not None:
1226
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1227
+ input_mask = np.array(input_mask)
1228
+ if original_mask is not None:
1229
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1230
+ original_mask = np.array(original_mask)
1231
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1232
+ else:
1233
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1234
+ pass
1235
+ else:
1236
+ if resize_default:
1237
+ short_side = min(output_w, output_h)
1238
+ scale_ratio = 640 / short_side
1239
+ output_w = int(output_w * scale_ratio)
1240
+ output_h = int(output_h * scale_ratio)
1241
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
1242
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
1243
+ original_image = np.array(original_image)
1244
+ if input_mask is not None:
1245
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
1246
+ input_mask = np.array(input_mask)
1247
+ if original_mask is not None:
1248
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
1249
+ original_mask = np.array(original_mask)
1250
+
1251
+ if input_mask.max() == 0:
1252
+ original_mask = original_mask
1253
+ else:
1254
+ original_mask = input_mask
1255
+
1256
+ if original_mask is None:
1257
+ raise gr.Error('Please generate mask first')
1258
+
1259
+ if original_mask.ndim == 2:
1260
+ original_mask = original_mask[:,:,None]
1261
+
1262
+ moved_mask = move_mask_func(original_mask, 'down', int(moving_pixels)).squeeze()
1263
+ mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
1264
+
1265
+ masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
1266
+ masked_image = masked_image.astype(original_image.dtype)
1267
+ masked_image = Image.fromarray(masked_image)
1268
+
1269
+ if moved_mask.max() <= 1:
1270
+ moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
1271
+ original_mask = moved_mask
1272
+
1273
+ return [masked_image], [mask_image], original_mask.astype(np.uint8)
1274
+
1275
+
1276
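+ # Invert the current mask (the drawn layer if present, otherwise the generated mask) and
+ # return the refreshed previews together with the invert-state flag.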
+ def invert_mask(input_image,
1277
+ original_image,
1278
+ original_mask,
1279
+ ):
1280
+ alpha_mask = input_image["layers"][0].split()[3]
1281
+ input_mask = np.asarray(alpha_mask)
1282
+ # Check for a missing mask before it is used, so the user gets a clear error instead of a
+ # TypeError when no mask has been generated or drawn yet.
+ if original_mask is None and input_mask.max() == 0:
+ raise gr.Error('Please generate mask first')
+
+ if input_mask.max() == 0:
+ original_mask = 1 - (original_mask>0).astype(np.uint8)
+ else:
+ original_mask = 1 - (input_mask>0).astype(np.uint8)
1289
+
1290
+ original_mask = original_mask.squeeze()
1291
+ mask_image = Image.fromarray(original_mask*255).convert("RGB")
1292
+
1293
+ if original_mask.ndim == 2:
1294
+ original_mask = original_mask[:,:,None]
1295
+
1296
+ if original_mask.max() <= 1:
1297
+ original_mask = (original_mask * 255).astype(np.uint8)
1298
+
1299
+ masked_image = original_image * (1 - (original_mask>0))
1300
+ masked_image = masked_image.astype(original_image.dtype)
1301
+ masked_image = Image.fromarray(masked_image)
1302
+
1303
+ return [masked_image], [mask_image], original_mask, True
1304
+
1305
+
1306
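+ # Initialize the editor when a new image is uploaded: preset examples load their cached
+ # mask/masked/result galleries (for the first two example switches), otherwise all galleries
+ # are cleared. Images with an aspect ratio above 2.0 are rejected.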
+ def init_img(base,
1307
+ init_type,
1308
+ prompt,
1309
+ aspect_ratio,
1310
+ example_change_times
1311
+ ):
1312
+ image_pil = base["background"].convert("RGB")
1313
+ original_image = np.array(image_pil)
1314
+ if max(original_image.shape[0], original_image.shape[1]) * 1.0 / min(original_image.shape[0], original_image.shape[1])>2.0:
1315
+ raise gr.Error('image aspect ratio cannot be larger than 2.0')
1316
+ if init_type in MASK_IMAGE_PATH.keys() and example_change_times < 2:
1317
+ mask_gallery = [Image.open(MASK_IMAGE_PATH[init_type]).convert("L")]
1318
+ masked_gallery = [Image.open(MASKED_IMAGE_PATH[init_type]).convert("RGB")]
1319
+ result_gallery = [Image.open(OUTPUT_IMAGE_PATH[init_type]).convert("RGB")]
1320
+ width, height = image_pil.size
1321
+ image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
1322
+ height_new, width_new = image_processor.get_default_height_width(image_pil, height, width)
1323
+ image_pil = image_pil.resize((width_new, height_new))
1324
+ mask_gallery[0] = mask_gallery[0].resize((width_new, height_new))
1325
+ masked_gallery[0] = masked_gallery[0].resize((width_new, height_new))
1326
+ result_gallery[0] = result_gallery[0].resize((width_new, height_new))
1327
+ original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1
1328
+ return base, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, "", "", "Custom resolution", False, False, example_change_times
1329
+ else:
1330
+ if aspect_ratio not in ASPECT_RATIO_LABELS:
1331
+ aspect_ratio = "Custom resolution"
1332
+ return base, original_image, None, "", None, None, None, "", "", aspect_ratio, True, False, 0
1333
+
1334
+
1335
+ def reset_func(input_image,
1336
+ original_image,
1337
+ original_mask,
1338
+ prompt,
1339
+ target_prompt,
1340
+ ):
1341
+ input_image = None
1342
+ original_image = None
1343
+ original_mask = None
1344
+ prompt = ''
1345
+ mask_gallery = []
1346
+ masked_gallery = []
1347
+ result_gallery = []
1348
+ target_prompt = ''
1349
+ if torch.cuda.is_available():
1350
+ torch.cuda.empty_cache()
1351
+ return input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, True, False
1352
+
1353
+
1354
+ def update_example(example_type,
1355
+ prompt,
1356
+ example_change_times):
1357
+ input_image = INPUT_IMAGE_PATH[example_type]
1358
+ image_pil = Image.open(input_image).convert("RGB")
1359
+ mask_gallery = [Image.open(MASK_IMAGE_PATH[example_type]).convert("L")]
1360
+ masked_gallery = [Image.open(MASKED_IMAGE_PATH[example_type]).convert("RGB")]
1361
+ result_gallery = [Image.open(OUTPUT_IMAGE_PATH[example_type]).convert("RGB")]
1362
+ width, height = image_pil.size
1363
+ image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
1364
+ height_new, width_new = image_processor.get_default_height_width(image_pil, height, width)
1365
+ image_pil = image_pil.resize((width_new, height_new))
1366
+ mask_gallery[0] = mask_gallery[0].resize((width_new, height_new))
1367
+ masked_gallery[0] = masked_gallery[0].resize((width_new, height_new))
1368
+ result_gallery[0] = result_gallery[0].resize((width_new, height_new))
1369
+
1370
+ original_image = np.array(image_pil)
1371
+ original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1
1372
+ aspect_ratio = "Custom resolution"
1373
+ example_change_times += 1
1374
+ return input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, "", False, example_change_times
1375
+
1376
+
1377
+ def generate_target_prompt(input_image,
1378
+ original_image,
1379
+ prompt):
1380
+ # load example image
1381
+ if isinstance(original_image, str):
1382
+ original_image = input_image
1383
+
1384
+ prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(
1385
+ vlm_processor,
1386
+ vlm_model,
1387
+ original_image,
1388
+ prompt,
1389
+ device)
1390
+ return prompt_after_apply_instruction
1391
+
1392
+
1393
+ # Newly added event handler functions
1394
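+ # Note: the BLIP processor and model are re-instantiated on every upload; caching them at
+ # module level would avoid repeated model loads.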
+ def generate_blip_description(input_image):
1395
+ if input_image is None:
1396
+ return "", "Input image cannot be None"
1397
+ from app.utils.utils import generate_caption
1398
+ blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
1399
+ blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large", torch_dtype=torch.float16).to(device)
1400
+ try:
1401
+ image_pil = input_image["background"].convert("RGB")
1402
+ except KeyError:
1403
+ return "", "Input image missing 'background' key"
1404
+ except AttributeError as e:
1405
+ return "", f"Invalid image object: {str(e)}"
1406
+ try:
1407
+ description = generate_caption(blip_processor, blip_model, image_pil, device)
1408
+ return description, description # update both the state and the displayed textbox
1409
+ except Exception as e:
1410
+ return "", f"Caption generation failed: {str(e)}"
1411
+
1412
+
1413
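+ # Despite the GPT-4o naming, this handler validates the submitted key against the DeepSeek
+ # endpoint with a one-turn test request and keeps the VLM dropdown on the GPT4-o entry.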
+ def submit_GPT4o_KEY(GPT4o_KEY):
1414
+ global vlm_model, vlm_processor
1415
+ if vlm_model is not None:
1416
+ del vlm_model
1417
+ torch.cuda.empty_cache()
1418
+ try:
1419
+ vlm_model = OpenAI(api_key=GPT4o_KEY, base_url="https://api.deepseek.com")
1420
+ vlm_processor = ""
1421
+ response = vlm_model.chat.completions.create(
1422
+ model="deepseek-chat",
1423
+ messages=[
1424
+ {"role": "system", "content": "You are a helpful assistant."},
1425
+ {"role": "user", "content": "Hello."}
1426
+ ]
1427
+ )
1428
+ response_str = response.choices[0].message.content
1429
+
1430
+ return "Success. " + response_str, "GPT4-o (Highly Recommended)"
1431
+ except Exception as e:
1432
+ return "Invalid GPT4o API Key", "GPT4-o (Highly Recommended)"
1433
+
1434
+
1435
+
1436
+ def verify_deepseek_api():
1437
+ try:
1438
+ response = llm_model.chat.completions.create(
1439
+ model="deepseek-chat",
1440
+ messages=[
1441
+ {"role": "system", "content": "You are a helpful assistant."},
1442
+ {"role": "user", "content": "Hello."}
1443
+ ]
1444
+ )
1445
+ response_str = response.choices[0].message.content
1446
+
1447
+ return True, "Success. " + response_str
1448
+
1449
+ except Exception as e:
1450
+ return False, "Invalid DeepSeek API Key"
1451
+
1452
+
1453
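+ # DeepSeek text helpers: the first merges the BLIP caption and the editing instruction into an
+ # integrated target description; the second decomposes that integrated query into sub-queries.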
+ def llm_enhanced_prompt_after_apply_instruction(image_caption, editing_prompt):
1454
+ try:
1455
+ messages = create_apply_editing_messages_deepseek(image_caption, editing_prompt)
1456
+ response = llm_model.chat.completions.create(
1457
+ model="deepseek-chat",
1458
+ messages=messages
1459
+ )
1460
+ response_str = response.choices[0].message.content
1461
+ return response_str
1462
+ except Exception as e:
1463
+ raise gr.Error(f"整合指令时遇到错误: {str(e)},请检查控制台日志获取详细信息")
1464
+
1465
+
1466
+ def llm_decomposed_prompt_after_apply_instruction(integrated_query):
1467
+ try:
1468
+ messages = create_decomposed_query_messages_deepseek(integrated_query)
1469
+ response = llm_model.chat.completions.create(
1470
+ model="deepseek-chat",
1471
+ messages=messages
1472
+ )
1473
+ response_str = response.choices[0].message.content
1474
+ return response_str
1475
+ except Exception as e:
1476
+ raise gr.Error(f"分解指令时遇到错误: {str(e)},请检查控制台日志获取详细信息")
1477
+
1478
+
1479
+ def enhance_description(blip_description, prompt):
1480
+ try:
1481
+ if not prompt or not blip_description:
1482
+ print("Empty prompt or blip_description detected")
1483
+ return "", ""
1484
+
1485
+ print(f"Enhancing with prompt: {prompt}")
1486
+ enhanced_description = llm_enhanced_prompt_after_apply_instruction(blip_description, prompt)
1487
+ return enhanced_description, enhanced_description
1488
+
1489
+ except Exception as e:
1490
+ print(f"Enhancement failed: {str(e)}")
1491
+ return "Error occurred", "Error occurred"
1492
+
1493
+ def decompose_description(enhanced_description):
1494
+ try:
1495
+ if not enhanced_description:
1496
+ print("Empty enhanced_description detected")
1497
+ return "", ""
1498
+
1499
+ print(f"Decomposing the enhanced description: {enhanced_description}")
1500
+ decomposed_description = llm_decomposed_prompt_after_apply_instruction(enhanced_description)
1501
+ return decomposed_description, decomposed_description
1502
+
1503
+ except Exception as e:
1504
+ print(f"Decomposition failed: {str(e)}")
1505
+ return "Error occurred", "Error occurred"
1506
+
1507
+
1508
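+ ## Gradio UI: blocks layout, state holders, and event bindings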
+ block = gr.Blocks(
1509
+ theme=gr.themes.Soft(
1510
+ radius_size=gr.themes.sizes.radius_none,
1511
+ text_size=gr.themes.sizes.text_md
1512
+ )
1513
+ )
1514
+ with block as demo:
1515
+ with gr.Row():
1516
+ with gr.Column():
1517
+ gr.HTML(head)
1518
+
1519
+ gr.Markdown(descriptions)
1520
+
1521
+ with gr.Accordion(label="🧭 Instructions:", open=True, elem_id="accordion"):
1522
+ with gr.Row(equal_height=True):
1523
+ gr.Markdown(instructions)
1524
+
1525
+ original_image = gr.State(value=None)
1526
+ original_mask = gr.State(value=None)
1527
+ category = gr.State(value=None)
1528
+ status = gr.State(value=None)
1529
+ invert_mask_state = gr.State(value=False)
1530
+ example_change_times = gr.State(value=0)
1531
+ deepseek_verified = gr.State(value=False)
1532
+ blip_description = gr.State(value="")
1533
+ enhanced_description = gr.State(value="")
1534
+ decomposed_description = gr.State(value="")
1535
+
1536
+ with gr.Row():
1537
+ with gr.Column():
1538
+ with gr.Row():
1539
+ input_image = gr.ImageEditor(
1540
+ label="参考图像",
1541
+ type="pil",
1542
+ brush=gr.Brush(colors=["#FFFFFF"], default_size = 30, color_mode="fixed"),
1543
+ layers = False,
1544
+ interactive=True,
1545
+ # height=1024,
1546
+ height=512,
1547
+ sources=["upload"],
1548
+ placeholder="🫧 点击此处或下面的图标上传图像 🫧",
1549
+ )
1550
+
1551
+ prompt = gr.Textbox(label="修改指令", placeholder="😜 在此处输入你对参考图像的修改预期 😜", value="",lines=1)
1552
+ run_button = gr.Button("💫 图像编辑")
1553
+
1554
+ vlm_model_dropdown = gr.Dropdown(label="VLM 模型", choices=VLM_MODEL_NAMES, value=DEFAULT_VLM_MODEL_NAME, interactive=True)
1555
+ with gr.Group():
1556
+ with gr.Row():
1557
+ # GPT4o_KEY = gr.Textbox(label="GPT4o API Key", placeholder="Please input your GPT4o API Key when use GPT4o VLM (highly recommended).", value="", lines=1)
1558
+ GPT4o_KEY = gr.Textbox(label="密钥输入", value="sk-d145b963a92649a88843caeb741e8bbc", lines=1)
1559
+ GPT4o_KEY_submit = gr.Button("🙈 验证")
1560
+
1561
+
1562
+ aspect_ratio = gr.Dropdown(label="Output aspect ratio", choices=ASPECT_RATIO_LABELS, value=DEFAULT_ASPECT_RATIO)
1563
+ resize_default = gr.Checkbox(label="Short edge resize to 640px", value=True)
1564
+
1565
+ with gr.Row():
1566
+ mask_button = gr.Button("💎 掩膜生成")
1567
+ random_mask_button = gr.Button("Square/Circle Mask ")
1568
+
1569
+
1570
+ with gr.Row():
1571
+ generate_target_prompt_button = gr.Button("Generate Target Prompt")
1572
+
1573
+ target_prompt = gr.Text(
1574
+ label="Input Target Prompt",
1575
+ max_lines=5,
1576
+ placeholder="VLM-generated target prompt; you can generate it first and then modify it (optional)",
1577
+ value='',
1578
+ lines=2
1579
+ )
1580
+
1581
+ with gr.Accordion("Advanced Options", open=False, elem_id="accordion1"):
1582
+ base_model_dropdown = gr.Dropdown(label="Base model", choices=BASE_MODELS, value=DEFAULT_BASE_MODEL, interactive=True)
1583
+ negative_prompt = gr.Text(
1584
+ label="Negative Prompt",
1585
+ max_lines=5,
1586
+ placeholder="Please input your negative prompt",
1587
+ value='ugly, low quality',lines=1
1588
+ )
1589
+
1590
+ control_strength = gr.Slider(
1591
+ label="Control Strength: ", show_label=True, minimum=0, maximum=1.1, value=1, step=0.01
1592
+ )
1593
+ with gr.Group():
1594
+ seed = gr.Slider(
1595
+ label="Seed: ", minimum=0, maximum=2147483647, step=1, value=648464818
1596
+ )
1597
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
1598
+
1599
+ blending = gr.Checkbox(label="Blending mode", value=True)
1600
+
1601
+
1602
+ num_samples = gr.Slider(
1603
+ label="Num samples", minimum=0, maximum=4, step=1, value=4
1604
+ )
1605
+
1606
+ with gr.Group():
1607
+ with gr.Row():
1608
+ guidance_scale = gr.Slider(
1609
+ label="Guidance scale",
1610
+ minimum=1,
1611
+ maximum=12,
1612
+ step=0.1,
1613
+ value=7.5,
1614
+ )
1615
+ num_inference_steps = gr.Slider(
1616
+ label="Number of inference steps",
1617
+ minimum=1,
1618
+ maximum=50,
1619
+ step=1,
1620
+ value=50,
1621
+ )
1622
+
1623
+
1624
+ with gr.Group(visible=True):
1625
+ # Caption generated by BLIP
1626
+ blip_output = gr.Textbox(label="原图描述", placeholder="💬 BLIP生成的图像基础描述 💬", interactive=True, lines=3)
1627
+ # DeepSeek API key verification
1628
+ with gr.Row():
1629
+ deepseek_key = gr.Textbox(label="密钥输入", value="sk-d145b963a92649a88843caeb741e8bbc", lines=1)
1630
+ verify_deepseek = gr.Button("🙈 验证")
1631
+ # Integrated (enhanced) description area
1632
+ with gr.Row():
1633
+ enhanced_output = gr.Textbox(label="描述整合", placeholder="💭 DeepSeek生成的增强描述 💭", interactive=True, lines=3)
1634
+ enhance_button = gr.Button("✨ 整合")
1635
+ # Decomposed description area
1636
+ with gr.Row():
1637
+ decomposed_output = gr.Textbox(label="描述分解", placeholder="🔍 DeepSeek生成的分解描述 🔍", interactive=True, lines=3)
1638
+ decompose_button = gr.Button("🔧 分解")
1639
+ with gr.Row():
1640
+ with gr.Tab(elem_classes="feedback", label="Masked Image"):
1641
+ masked_gallery = gr.Gallery(label='Masked Image', show_label=True, elem_id="gallery", preview=True, height=360)
1642
+ with gr.Tab(elem_classes="feedback", label="Mask"):
1643
+ mask_gallery = gr.Gallery(label='Mask', show_label=True, elem_id="gallery", preview=True, height=360)
1644
+
1645
+ invert_mask_button = gr.Button("Invert Mask")
1646
+ dilation_size = gr.Slider(
1647
+ label="Dilation size: ", minimum=0, maximum=50, step=1, value=20
1648
+ )
1649
+ with gr.Row():
1650
+ dilation_mask_button = gr.Button("Dilation Generated Mask")
1651
+ erosion_mask_button = gr.Button("Erosion Generated Mask")
1652
+
1653
+ moving_pixels = gr.Slider(
1654
+ label="Moving pixels:", show_label=True, minimum=0, maximum=50, value=4, step=1
1655
+ )
1656
+ with gr.Row():
1657
+ move_left_button = gr.Button("Move Left")
1658
+ move_right_button = gr.Button("Move Right")
1659
+ with gr.Row():
1660
+ move_up_button = gr.Button("Move Up")
1661
+ move_down_button = gr.Button("Move Down")
1662
+
1663
+ with gr.Tab(elem_classes="feedback", label="Output"):
1664
+ result_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", preview=True, height=400)
1665
+
1666
+ # target_prompt_output = gr.Text(label="Output Target Prompt", value="", lines=1, interactive=False)
1667
+
1668
+ reset_button = gr.Button("Reset")
1669
+
1670
+ init_type = gr.Textbox(label="Init Name", value="", visible=False)
1671
+ example_type = gr.Textbox(label="Example Name", value="", visible=False)
1672
+
1673
+
1674
+
1675
+ with gr.Row():
1676
+ example = gr.Examples(
1677
+ label="Quick Example",
1678
+ examples=EXAMPLES,
1679
+ inputs=[input_image, prompt, seed, init_type, example_type, blending, resize_default, vlm_model_dropdown],
1680
+ examples_per_page=10,
1681
+ cache_examples=False,
1682
+ )
1683
+
1684
+
1685
+ with gr.Accordion(label="🎬 Feature Details:", open=True, elem_id="accordion"):
1686
+ with gr.Row(equal_height=True):
1687
+ gr.Markdown(tips)
1688
+
1689
+ with gr.Row():
1690
+ gr.Markdown(citation)
1691
+
1692
+ ## gr.Examples cannot update the gr.Gallery directly, so the following two functions refresh the galleries instead.
1693
+ ## They also resolve the conflict between the image-upload and example-change handlers.
1694
+ input_image.upload(
1695
+ init_img,
1696
+ [input_image, init_type, prompt, aspect_ratio, example_change_times],
1697
+ [input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, init_type, aspect_ratio, resize_default, invert_mask_state, example_change_times]
1698
+ )
1699
+ example_type.change(fn=update_example, inputs=[example_type, prompt, example_change_times], outputs=[input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, target_prompt, invert_mask_state, example_change_times])
1700
+
1701
+ ## vlm and base model dropdown
1702
+ vlm_model_dropdown.change(fn=update_vlm_model, inputs=[vlm_model_dropdown], outputs=[status])
1703
+ base_model_dropdown.change(fn=update_base_model, inputs=[base_model_dropdown], outputs=[status])
1704
+
1705
+
1706
+ GPT4o_KEY_submit.click(fn=submit_GPT4o_KEY, inputs=[GPT4o_KEY], outputs=[GPT4o_KEY, vlm_model_dropdown])
1707
+ invert_mask_button.click(fn=invert_mask, inputs=[input_image, original_image, original_mask], outputs=[masked_gallery, mask_gallery, original_mask, invert_mask_state])
1708
+
1709
+
1710
+ ips=[input_image,
1711
+ original_image,
1712
+ original_mask,
1713
+ prompt,
1714
+ negative_prompt,
1715
+ control_strength,
1716
+ seed,
1717
+ randomize_seed,
1718
+ guidance_scale,
1719
+ num_inference_steps,
1720
+ num_samples,
1721
+ blending,
1722
+ category,
1723
+ target_prompt,
1724
+ resize_default,
1725
+ aspect_ratio,
1726
+ invert_mask_state]
1727
+
1728
+ ## run brushedit
1729
+ run_button.click(fn=process, inputs=ips, outputs=[result_gallery, mask_gallery, masked_gallery, prompt, target_prompt, invert_mask_state])
1730
+
1731
+ ## mask func
1732
+ mask_button.click(fn=process_mask, inputs=[input_image, original_image, prompt, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask, category])
1733
+ random_mask_button.click(fn=process_random_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
1734
+ dilation_mask_button.click(fn=process_dilation_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[ masked_gallery, mask_gallery, original_mask])
1735
+ erosion_mask_button.click(fn=process_erosion_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[ masked_gallery, mask_gallery, original_mask])
1736
+
1737
+ ## move mask func
1738
+ move_left_button.click(fn=move_mask_left, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
1739
+ move_right_button.click(fn=move_mask_right, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
1740
+ move_up_button.click(fn=move_mask_up, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
1741
+ move_down_button.click(fn=move_mask_down, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
1742
+
1743
+ ## prompt func
1744
+ generate_target_prompt_button.click(fn=generate_target_prompt, inputs=[input_image, original_image, prompt], outputs=[target_prompt])
1745
+
1746
+ ## reset func
1747
+ reset_button.click(fn=reset_func, inputs=[input_image, original_image, original_mask, prompt, target_prompt], outputs=[input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, resize_default, invert_mask_state])
1748
+
1749
+
1750
+ # Bind event handlers
1751
+ input_image.upload(fn=generate_blip_description, inputs=[input_image], outputs=[blip_description, blip_output])
1752
+ verify_deepseek.click(fn=verify_deepseek_api, outputs=[deepseek_verified, deepseek_key])
1753
+ enhance_button.click(fn=enhance_description, inputs=[blip_description, prompt], outputs=[enhanced_description, enhanced_output])
1754
+ decompose_button.click(fn=decompose_description, inputs=[enhanced_description], outputs=[decomposed_description, decomposed_output])
1755
+
1756
+ demo.launch(server_name="0.0.0.0", server_port=12345, share=True)
1757
+
1758
+
llm_pipeline.py ADDED
@@ -0,0 +1,35 @@
1
+ import gradio as gr
2
+ from openai import OpenAI
3
+ from app.deepseek.instructions import create_apply_editing_messages_deepseek
4
+
5
+
6
+ def run_deepseek_llm_inference(llm_model, messages):
7
+ response = llm_model.chat.completions.create(
8
+ model="deepseek-chat",
9
+ messages=messages
10
+ )
11
+ response_str = response.choices[0].message.content
12
+ return response_str
13
+
14
+
15
+ from openai import AuthenticationError, APIConnectionError, RateLimitError, BadRequestError, APIError
16
+
17
+ def llm_response_prompt_after_apply_instruction(image_caption, editing_prompt):
18
+ try:
19
+ messages = create_apply_editing_messages_deepseek(image_caption, editing_prompt)
20
+ response_str = run_deepseek_llm_inference(llm_model, messages)
21
+ return response_str
22
+ except AuthenticationError as e:
23
+ raise gr.Error(f"认证失败: 请检查API密钥是否正确 (错误详情: {e.message})")
24
+ except APIConnectionError as e:
25
+ raise gr.Error(f"连接异常: 请检查网络连接后重试 (错误详情: {e.message})")
26
+ except RateLimitError as e:
27
+ raise gr.Error(f"请求超限: 请稍后重试 (错误详情: {e.message})")
28
+ except BadRequestError as e:
29
+ if "model" in e.message.lower():
30
+ raise gr.Error(f"模型错误: 请检查模型名称是否正确 (错误详情: {e.message})")
31
+ raise gr.Error(f"无效请求: 请检查输入参数 (错误详情: {e.message})")
32
+ except APIError as e:
33
+ raise gr.Error(f"API异常: 服务端返回错误 (错误详情: {e.message})")
34
+ except Exception as e:
35
+ raise gr.Error(f"未预期错误: {str(e)},请检查控制台日志获取详细信息")
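+
+ # Note (assumption): llm_model is not defined in this module; the caller is expected to provide
+ # it, e.g. as a module-level OpenAI client configured for the DeepSeek endpoint.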
llm_template.py ADDED
@@ -0,0 +1,21 @@
1
+ import os
2
+ import sys
3
+ import torch
4
+ from openai import OpenAI
5
+
6
+ ## init device
7
+ device = "cpu"
8
+ torch_dtype = torch.float16
9
+
10
+
11
+ llms_list = [
12
+ {
13
+ "type": "deepseek",
14
+ "name": "deepseek",
15
+ "local_path": "",
16
+ "processor": "",
17
+ "model": ""
18
+ },
19
+ ]
20
+
21
+ llms_template = {k["name"]: (k["type"], k["local_path"], k["processor"], k["model"]) for k in llms_list}
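+
+ # Hypothetical usage of the registry above (the single "deepseek" entry keeps empty
+ # local_path/processor/model placeholders because the client is created from an API key at runtime):
+ # llm_type, llm_local_path, llm_processor, llm_model = llms_template["deepseek"]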
vlm_pipeline.py ADDED
@@ -0,0 +1,266 @@
1
+ import base64
2
+ import re
3
+ import torch
4
+
5
+ from PIL import Image
6
+ from io import BytesIO
7
+ import numpy as np
8
+ import gradio as gr
9
+
10
+ from openai import OpenAI
11
+ from transformers import (LlavaNextForConditionalGeneration, Qwen2VLForConditionalGeneration)
12
+
13
+
14
+ from qwen_vl_utils import process_vision_info
15
+
16
+ from app.gpt4_o.instructions import (
17
+ create_editing_category_messages_gpt4o,
18
+ create_ori_object_messages_gpt4o,
19
+ create_add_object_messages_gpt4o,
20
+ create_apply_editing_messages_gpt4o)
21
+
22
+ from app.llava.instructions import (
23
+ create_editing_category_messages_llava,
24
+ create_ori_object_messages_llava,
25
+ create_add_object_messages_llava,
26
+ create_apply_editing_messages_llava)
27
+
28
+ from app.qwen2.instructions import (
29
+ create_editing_category_messages_qwen2,
30
+ create_ori_object_messages_qwen2,
31
+ create_add_object_messages_qwen2,
32
+ create_apply_editing_messages_qwen2)
33
+
34
+ from app.deepseek.instructions import (
35
+ create_editing_category_messages_deepseek,
36
+ create_ori_object_messages_deepseek,
37
+ create_apply_editing_messages_deepseek
38
+ )
39
+
40
+ from app.utils.utils import run_grounded_sam
41
+
42
+
43
+ def encode_image(img):
44
+ img = Image.fromarray(img.astype('uint8'))
45
+ buffered = BytesIO()
46
+ img.save(buffered, format="PNG")
47
+ img_bytes = buffered.getvalue()
48
+ return base64.b64encode(img_bytes).decode('utf-8')
49
+
50
+
51
+ def run_gpt4o_vl_inference(vlm_model,
52
+ messages):
53
+ response = vlm_model.chat.completions.create(
54
+ model="gpt-4o-2024-08-06",
55
+ messages=messages
56
+ )
57
+ response_str = response.choices[0].message.content
58
+ return response_str
59
+
60
+
61
+ def run_deepseek_inference(llm_model,
62
+ messages):
63
+ try:
64
+ response = llm_model.chat.completions.create(
65
+ model="deepseek-chat",
66
+ messages=messages
67
+ )
68
+ response_str = response.choices[0].message.content
69
+ return response_str
70
+ except Exception as e:
71
+ return "Invalid DeepSeek API Key"
72
+
73
+
74
+ def run_llava_next_inference(vlm_processor, vlm_model, messages, image, device="cuda"):
75
+ prompt = vlm_processor.apply_chat_template(messages, add_generation_prompt=True)
76
+ inputs = vlm_processor(images=image, text=prompt, return_tensors="pt").to(device)
77
+ output = vlm_model.generate(**inputs, max_new_tokens=200)
78
+ generated_ids_trimmed = [
79
+ out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, output)
80
+ ]
81
+ response_str = vlm_processor.decode(generated_ids_trimmed[0], skip_special_tokens=True)
82
+
83
+ return response_str
84
+
85
+
86
+ def run_qwen2_vl_inference(vlm_processor, vlm_model, messages, image, device="cuda"):
87
+ text = vlm_processor.apply_chat_template(
88
+ messages, tokenize=False, add_generation_prompt=True
89
+ )
90
+ image_inputs, video_inputs = process_vision_info(messages)
91
+ inputs = vlm_processor(
92
+ text=[text],
93
+ images=image_inputs,
94
+ videos=video_inputs,
95
+ padding=True,
96
+ return_tensors="pt",
97
+ )
98
+ inputs = inputs.to(device)
99
+ generated_ids = vlm_model.generate(**inputs, max_new_tokens=128)
100
+ generated_ids_trimmed = [
101
+ out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
102
+ ]
103
+ response_str = vlm_processor.decode(generated_ids_trimmed[0], skip_special_tokens=True)
104
+ return response_str
105
+
106
+
107
+ ### response editing type
108
+ def vlm_response_editing_type(vlm_processor,
109
+ vlm_model,
110
+ llm_model,
111
+ image,
112
+ image_caption,
113
+ editing_prompt,
114
+ device):
115
+
116
+ if isinstance(vlm_model, OpenAI):
117
+ messages = create_editing_category_messages_gpt4o(editing_prompt)
118
+ response_str = run_gpt4o_vl_inference(vlm_model, messages)
119
+ elif isinstance(vlm_model, LlavaNextForConditionalGeneration):
120
+ messages = create_editing_category_messages_llava(editing_prompt)
121
+ response_str = run_llava_next_inference(vlm_processor, vlm_model, messages, image, device=device)
122
+ elif isinstance(vlm_model, Qwen2VLForConditionalGeneration):
123
+ # messages = create_editing_category_messages_qwen2(editing_prompt)
124
+ messages = create_editing_category_messages_qwen2(image_caption, editing_prompt)
125
+ response_str = run_qwen2_vl_inference(vlm_processor, vlm_model, messages, image, device=device)
126
+ # messages = create_editing_category_messages_deepseek(image_caption, editing_prompt)
127
+ # response_str = run_deepseek_inference(llm_model, messages)
128
+
129
+ # If the VLM response is unusable or does not mention any known editing category,
+ # fall through to a clear error instead of silently returning None.
+ try:
+ for category_name in ["Addition", "Remove", "Local", "Global", "Background"]:
+ if category_name.lower() in response_str.lower():
+ return category_name
+ except Exception:
+ pass
+ raise gr.Error("Please provide a valid OpenAI API key, or a valid editing instruction (add, delete, or modify something). If it still does not work, please switch to a more powerful VLM.")
135
+
136
+
137
+ ### response object to be edited
138
+ def vlm_response_object_wait_for_edit(vlm_processor,
139
+ vlm_model,
140
+ llm_model,
141
+ image,
142
+ image_caption,
143
+ category,
144
+ editing_prompt,
145
+ device):
146
+ if category in ["Background", "Global", "Addition"]:
147
+ edit_object = "nan"
148
+ return edit_object
149
+
150
+ if isinstance(vlm_model, OpenAI):
151
+ messages = create_ori_object_messages_gpt4o(editing_prompt)
152
+ response_str = run_gpt4o_vl_inference(vlm_model, messages)
153
+ elif isinstance(vlm_model, LlavaNextForConditionalGeneration):
154
+ messages = create_ori_object_messages_llava(editing_prompt)
155
+ response_str = run_llava_next_inference(vlm_processor, vlm_model, messages, image, device)
156
+ elif isinstance(vlm_model, Qwen2VLForConditionalGeneration):
157
+ # messages = create_ori_object_messages_qwen2(editing_prompt)
158
+ messages = create_ori_object_messages_qwen2(image_caption, editing_prompt)
159
+ response_str = run_qwen2_vl_inference(vlm_processor, vlm_model, messages, image, device)
160
+ # messages = create_ori_object_messages_deepseek(image_caption, editing_prompt)
161
+ # response_str = run_deepseek_inference(llm_model, messages)
162
+ return response_str
163
+
164
+
165
+ ### response mask
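+ # Mask strategy: "Addition" asks the VLM for a bounding box and rasterizes it into a mask;
+ # "Global" returns an all-zero mask; "Background" and object-level edits run Grounded-SAM on
+ # the text label, retrying with progressively lower box thresholds until a mask is found.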
166
+ def vlm_response_mask(vlm_processor,
167
+ vlm_model,
168
+ category,
169
+ image,
170
+ editing_prompt,
171
+ object_wait_for_edit,
172
+ sam=None,
173
+ sam_predictor=None,
174
+ sam_automask_generator=None,
175
+ groundingdino_model=None,
176
+ device=None,
177
+ ):
178
+ mask = None
179
+ if editing_prompt is None or len(editing_prompt)==0:
180
+ raise gr.Error("Please input the editing instruction!")
181
+ height, width = image.shape[:2]
182
+ if category=="Addition":
183
+ try:
184
+ if isinstance(vlm_model, OpenAI):
185
+ base64_image = encode_image(image)
186
+ messages = create_add_object_messages_gpt4o(editing_prompt, base64_image, height=height, width=width)
187
+ response_str = run_gpt4o_vl_inference(vlm_model, messages)
188
+ elif isinstance(vlm_model, LlavaNextForConditionalGeneration):
189
+ messages = create_add_object_messages_llava(editing_prompt, height=height, width=width)
190
+ response_str = run_llava_next_inference(vlm_processor, vlm_model, messages, image, device)
191
+ elif isinstance(vlm_model, Qwen2VLForConditionalGeneration):
192
+ base64_image = encode_image(image)
193
+ messages = create_add_object_messages_qwen2(editing_prompt, base64_image, height=height, width=width)
194
+ response_str = run_qwen2_vl_inference(vlm_processor, vlm_model, messages, image, device)
195
+ pattern = r'\[\d{1,3}(?:,\s*\d{1,3}){3}\]'
196
+ box = re.findall(pattern, response_str)
197
+ box = box[0][1:-1].split(",")
198
+ for i in range(len(box)):
199
+ box[i] = int(box[i])
200
+ cus_mask = np.zeros((height, width))
201
+ cus_mask[box[1]: box[1]+box[3], box[0]: box[0]+box[2]]=255
202
+ mask = cus_mask
203
+ except:
204
+ raise gr.Error("Please set the mask manually, currently the VLM cannot output the mask!")
205
+
206
+ elif category=="Background":
207
+ labels = "background"
208
+ elif category=="Global":
209
+ mask = 255 * np.zeros((height, width))
210
+ else:
211
+ labels = object_wait_for_edit
212
+
213
+ if mask is None:
214
+ for thresh in [0.3,0.25,0.2,0.15,0.1,0.05,0]:
215
+ try:
216
+ detections = run_grounded_sam(
217
+ input_image={"image":Image.fromarray(image.astype('uint8')),
218
+ "mask":None},
219
+ text_prompt=labels,
220
+ task_type="seg",
221
+ box_threshold=thresh,
222
+ text_threshold=0.25,
223
+ # iou_threshold=0.5,
224
+ # scribble_mode="split",
225
+ sam=sam,
226
+ sam_predictor=sam_predictor,
227
+ # sam_automask_generator=sam_automask_generator,
228
+ groundingdino_model=groundingdino_model,
229
+ device=device,
230
+ )
231
+ mask = np.array(detections[0,0,...].cpu()) * 255
232
+ break
233
+ except:
234
+ print(f"Grounded-SAM failed at box threshold {thresh}, retrying with a lower threshold")
235
+ continue
236
+ return mask
237
+
238
+
239
+ def vlm_response_prompt_after_apply_instruction(vlm_processor,
240
+ vlm_model,
241
+ llm_model,
242
+ image,
243
+ image_caption,
244
+ editing_prompt,
245
+ device):
246
+
247
+ try:
248
+ if isinstance(vlm_model, OpenAI):
249
+ base64_image = encode_image(image)
250
+ messages = create_apply_editing_messages_gpt4o(editing_prompt, base64_image)
251
+ response_str = run_gpt4o_vl_inference(vlm_model, messages)
252
+ elif isinstance(vlm_model, LlavaNextForConditionalGeneration):
253
+ messages = create_apply_editing_messages_llava(editing_prompt)
254
+ response_str = run_llava_next_inference(vlm_processor, vlm_model, messages, image, device)
255
+ elif isinstance(vlm_model, Qwen2VLForConditionalGeneration):
256
+ # base64_image = encode_image(image)
257
+ # messages = create_apply_editing_messages_qwen2(editing_prompt, base64_image)
258
+ messages = create_apply_editing_messages_qwen2(image_caption, editing_prompt)
259
+ response_str = run_qwen2_vl_inference(vlm_processor, vlm_model, messages, image, device)
260
+ # messages = create_apply_editing_messages_deepseek(image_caption, editing_prompt)
261
+ # response_str = run_deepseek_inference(llm_model, messages)
262
+ else:
263
+ raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
264
+ except Exception as e:
265
+ raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
266
+ return response_str
vlm_pipeline_noqwen.py ADDED
@@ -0,0 +1,263 @@
1
+ import base64
2
+ import re
3
+ import torch
4
+
5
+ from PIL import Image
6
+ from io import BytesIO
7
+ import numpy as np
8
+ import gradio as gr
9
+
10
+ from openai import OpenAI
11
+ from transformers import (LlavaNextForConditionalGeneration, Qwen2VLForConditionalGeneration)
12
+
13
+
14
+ from qwen_vl_utils import process_vision_info
15
+
16
+ from app.gpt4_o.instructions import (
17
+ create_editing_category_messages_gpt4o,
18
+ create_ori_object_messages_gpt4o,
19
+ create_add_object_messages_gpt4o,
20
+ create_apply_editing_messages_gpt4o)
21
+
22
+ from app.llava.instructions import (
23
+ create_editing_category_messages_llava,
24
+ create_ori_object_messages_llava,
25
+ create_add_object_messages_llava,
26
+ create_apply_editing_messages_llava)
27
+
28
+ from app.qwen2.instructions import (
29
+ create_editing_category_messages_qwen2,
30
+ create_ori_object_messages_qwen2,
31
+ create_add_object_messages_qwen2,
32
+ create_apply_editing_messages_qwen2)
33
+
34
+ from app.deepseek.instructions import (
35
+ create_editing_category_messages_deepseek,
36
+ create_ori_object_messages_deepseek,
37
+ create_apply_editing_messages_deepseek
38
+ )
39
+
40
+ from app.utils.utils import run_grounded_sam
41
+
42
+
43
+ def encode_image(img):
44
+ img = Image.fromarray(img.astype('uint8'))
45
+ buffered = BytesIO()
46
+ img.save(buffered, format="PNG")
47
+ img_bytes = buffered.getvalue()
48
+ return base64.b64encode(img_bytes).decode('utf-8')
49
+
50
+
51
+ def run_gpt4o_vl_inference(vlm_model,
52
+ messages):
53
+ response = vlm_model.chat.completions.create(
54
+ model="gpt-4o-2024-08-06",
55
+ messages=messages
56
+ )
57
+ response_str = response.choices[0].message.content
58
+ return response_str
59
+
60
+
61
+ def run_deepseek_inference(llm_model,
62
+ messages):
63
+ try:
64
+ response = llm_model.chat.completions.create(
65
+ model="deepseek-chat",
66
+ messages=messages
67
+ )
68
+ response_str = response.choices[0].message.content
69
+ return response_str
70
+ except Exception as e:
71
+ return "Invalid DeepSeek API Key"
72
+
73
+
74
+ def run_llava_next_inference(vlm_processor, vlm_model, messages, image, device="cuda"):
75
+ prompt = vlm_processor.apply_chat_template(messages, add_generation_prompt=True)
76
+ inputs = vlm_processor(images=image, text=prompt, return_tensors="pt").to(device)
77
+ output = vlm_model.generate(**inputs, max_new_tokens=200)
78
+ generated_ids_trimmed = [
79
+ out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, output)
80
+ ]
81
+ response_str = vlm_processor.decode(generated_ids_trimmed[0], skip_special_tokens=True)
82
+
83
+ return response_str
84
+
85
+
86
+ def run_qwen2_vl_inference(vlm_processor, vlm_model, messages, image, device="cuda"):
87
+ text = vlm_processor.apply_chat_template(
88
+ messages, tokenize=False, add_generation_prompt=True
89
+ )
90
+ image_inputs, video_inputs = process_vision_info(messages)
91
+ inputs = vlm_processor(
92
+ text=[text],
93
+ images=image_inputs,
94
+ videos=video_inputs,
95
+ padding=True,
96
+ return_tensors="pt",
97
+ )
98
+ inputs = inputs.to(device)
99
+ generated_ids = vlm_model.generate(**inputs, max_new_tokens=128)
100
+ generated_ids_trimmed = [
101
+ out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
102
+ ]
103
+ response_str = vlm_processor.decode(generated_ids_trimmed[0], skip_special_tokens=True)
104
+ return response_str
105
+
106
+
107
+ ### response editing type
108
+ def vlm_response_editing_type(vlm_processor,
109
+ vlm_model,
110
+ llm_model,
111
+ image,
112
+ image_caption,
113
+ editing_prompt,
114
+ device):
115
+
116
+ if isinstance(vlm_model, OpenAI):
117
+ messages = create_editing_category_messages_gpt4o(editing_prompt)
118
+ response_str = run_gpt4o_vl_inference(vlm_model, messages)
119
+ elif isinstance(vlm_model, LlavaNextForConditionalGeneration):
120
+ messages = create_editing_category_messages_llava(editing_prompt)
121
+ response_str = run_llava_next_inference(vlm_processor, vlm_model, messages, image, device=device)
122
+ elif isinstance(vlm_model, Qwen2VLForConditionalGeneration):
123
+ # messages = create_editing_category_messages_qwen2(editing_prompt)
124
+ # response_str = run_qwen2_vl_inference(vlm_processor, vlm_model, messages, image, device=device)
125
+ messages = create_editing_category_messages_deepseek(image_caption, editing_prompt)
126
+ response_str = run_deepseek_inference(llm_model, messages)
127
+
128
+ try:
129
+ for category_name in ["Addition","Remove","Local","Global","Background"]:
130
+ if category_name.lower() in response_str.lower():
131
+ return category_name
132
+ except Exception as e:
133
+ raise gr.Error("Please input OpenAI API Key. Or please input correct commands, including add, delete, and modify commands. If it still does not work, please switch to a more powerful VLM.")
134
+
135
+
136
+ ### response object to be edited
137
+ def vlm_response_object_wait_for_edit(vlm_processor,
138
+ vlm_model,
139
+ llm_model,
140
+ image,
141
+ image_caption,
142
+ category,
143
+ editing_prompt,
144
+ device):
145
+ if category in ["Background", "Global", "Addition"]:
146
+ edit_object = "nan"
147
+ return edit_object
148
+
149
+ if isinstance(vlm_model, OpenAI):
150
+ messages = create_ori_object_messages_gpt4o(editing_prompt)
151
+ response_str = run_gpt4o_vl_inference(vlm_model, messages)
152
+ elif isinstance(vlm_model, LlavaNextForConditionalGeneration):
153
+ messages = create_ori_object_messages_llava(editing_prompt)
154
+ response_str = run_llava_next_inference(vlm_processor, vlm_model, messages, image, device)
155
+ elif isinstance(vlm_model, Qwen2VLForConditionalGeneration):
156
+ # messages = create_ori_object_messages_qwen2(editing_prompt)
157
+ # response_str = run_qwen2_vl_inference(vlm_processor, vlm_model, messages, image, device)
158
+ messages = create_ori_object_messages_deepseek(image_caption, editing_prompt)
159
+ response_str = run_deepseek_inference(llm_model, messages)
160
+ return response_str
161
+
162
+
163
+ ### response mask
164
+ def vlm_response_mask(vlm_processor,
165
+ vlm_model,
166
+ category,
167
+ image,
168
+ editing_prompt,
169
+ object_wait_for_edit,
170
+ sam=None,
171
+ sam_predictor=None,
172
+ sam_automask_generator=None,
173
+ groundingdino_model=None,
174
+ device=None,
175
+ ):
176
+ mask = None
177
+ if editing_prompt is None or len(editing_prompt)==0:
178
+ raise gr.Error("Please input the editing instruction!")
179
+ height, width = image.shape[:2]
180
+ if category=="Addition":
181
+ try:
182
+ if isinstance(vlm_model, OpenAI):
183
+ base64_image = encode_image(image)
184
+ messages = create_add_object_messages_gpt4o(editing_prompt, base64_image, height=height, width=width)
185
+ response_str = run_gpt4o_vl_inference(vlm_model, messages)
186
+ elif isinstance(vlm_model, LlavaNextForConditionalGeneration):
187
+ messages = create_add_object_messages_llava(editing_prompt, height=height, width=width)
188
+ response_str = run_llava_next_inference(vlm_processor, vlm_model, messages, image, device)
189
+ elif isinstance(vlm_model, Qwen2VLForConditionalGeneration):
190
+ base64_image = encode_image(image)
191
+ messages = create_add_object_messages_qwen2(editing_prompt, base64_image, height=height, width=width)
192
+ response_str = run_qwen2_vl_inference(vlm_processor, vlm_model, messages, image, device)
193
+ pattern = r'\[\d{1,3}(?:,\s*\d{1,3}){3}\]'
194
+ box = re.findall(pattern, response_str)
195
+ box = box[0][1:-1].split(",")
196
+ for i in range(len(box)):
197
+ box[i] = int(box[i])
198
+ cus_mask = np.zeros((height, width))
199
+ cus_mask[box[1]: box[1]+box[3], box[0]: box[0]+box[2]]=255
200
+ mask = cus_mask
201
+ except:
202
+ raise gr.Error("Please set the mask manually, currently the VLM cannot output the mask!")
203
+
204
+ elif category=="Background":
205
+ labels = "background"
206
+ elif category=="Global":
207
+ mask = 255 * np.zeros((height, width))
208
+ else:
209
+ labels = object_wait_for_edit
210
+
211
+ if mask is None:
212
+ for thresh in [0.3,0.25,0.2,0.15,0.1,0.05,0]:
213
+ try:
214
+ detections = run_grounded_sam(
215
+ input_image={"image":Image.fromarray(image.astype('uint8')),
216
+ "mask":None},
217
+ text_prompt=labels,
218
+ task_type="seg",
219
+ box_threshold=thresh,
220
+ text_threshold=0.25,
221
+ # iou_threshold=0.5,
222
+ # scribble_mode="split",
223
+ sam=sam,
224
+ sam_predictor=sam_predictor,
225
+ # sam_automask_generator=sam_automask_generator,
226
+ groundingdino_model=groundingdino_model,
227
+ device=device,
228
+ )
229
+ mask = np.array(detections[0,0,...].cpu()) * 255
230
+ break
231
+ except:
232
+ print(f"wrong in threshhold: {thresh}, continue")
233
+ continue
234
+ return mask
235
+
236
+
237
+ def vlm_response_prompt_after_apply_instruction(vlm_processor,
238
+ vlm_model,
239
+ llm_model,
240
+ image,
241
+ image_caption,
242
+ editing_prompt,
243
+ device):
244
+
245
+ try:
246
+ if isinstance(vlm_model, OpenAI):
247
+ base64_image = encode_image(image)
248
+ messages = create_apply_editing_messages_gpt4o(editing_prompt, base64_image)
249
+ response_str = run_gpt4o_vl_inference(vlm_model, messages)
250
+ elif isinstance(vlm_model, LlavaNextForConditionalGeneration):
251
+ messages = create_apply_editing_messages_llava(editing_prompt)
252
+ response_str = run_llava_next_inference(vlm_processor, vlm_model, messages, image, device)
253
+ elif isinstance(vlm_model, Qwen2VLForConditionalGeneration):
254
+ # base64_image = encode_image(image)
255
+ # messages = create_apply_editing_messages_qwen2(editing_prompt, base64_image)
256
+ # response_str = run_qwen2_vl_inference(vlm_processor, vlm_model, messages, image, device)
257
+ messages = create_apply_editing_messages_deepseek(image_caption, editing_prompt)
258
+ response_str = run_deepseek_inference(llm_model, messages)
259
+ else:
260
+ raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
261
+ except Exception as e:
262
+ raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
263
+ return response_str
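A minimal, self-contained sketch of the box-to-mask conversion performed in the "Addition" branch of vlm_response_mask above. The response string and image size below are made-up examples of the "[x, y, w, h]" format the VLM is prompted to return, not real model output.

import re
import numpy as np

response_str = "The new object fits best at [120, 80, 200, 160]."   # hypothetical VLM reply
height, width = 512, 512

pattern = r'\[\d{1,3}(?:,\s*\d{1,3}){3}\]'
box = re.findall(pattern, response_str)[0][1:-1].split(",")
x, y, w, h = (int(v) for v in box)

mask = np.zeros((height, width), dtype=np.uint8)
mask[y:y + h, x:x + w] = 255        # same fill rule as cus_mask in the pipeline
print(mask.sum() // 255)            # masked pixels: 200 * 160 = 32000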
vlm_pipeline_old.py ADDED
@@ -0,0 +1,228 @@
+ import base64
+ import re
+ import torch
+
+ from PIL import Image
+ from io import BytesIO
+ import numpy as np
+ import gradio as gr
+
+ from openai import OpenAI
+ from transformers import (LlavaNextForConditionalGeneration, Qwen2VLForConditionalGeneration)
+ from qwen_vl_utils import process_vision_info
+
+ from app.gpt4_o.instructions import (
+     create_editing_category_messages_gpt4o,
+     create_ori_object_messages_gpt4o,
+     create_add_object_messages_gpt4o,
+     create_apply_editing_messages_gpt4o)
+
+ from app.llava.instructions import (
+     create_editing_category_messages_llava,
+     create_ori_object_messages_llava,
+     create_add_object_messages_llava,
+     create_apply_editing_messages_llava)
+
+ from app.qwen2.instructions import (
+     create_editing_category_messages_qwen2,
+     create_ori_object_messages_qwen2,
+     create_add_object_messages_qwen2,
+     create_apply_editing_messages_qwen2)
+
+ from app.utils.utils import run_grounded_sam
+
+
+ def encode_image(img):
+     img = Image.fromarray(img.astype('uint8'))
+     buffered = BytesIO()
+     img.save(buffered, format="PNG")
+     img_bytes = buffered.getvalue()
+     return base64.b64encode(img_bytes).decode('utf-8')
+
+
+ def run_gpt4o_vl_inference(vlm_model,
+                            messages):
+     response = vlm_model.chat.completions.create(
+         model="gpt-4o-2024-08-06",
+         messages=messages
+     )
+     response_str = response.choices[0].message.content
+     return response_str
+
+ def run_llava_next_inference(vlm_processor, vlm_model, messages, image, device="cuda"):
+     prompt = vlm_processor.apply_chat_template(messages, add_generation_prompt=True)
+     inputs = vlm_processor(images=image, text=prompt, return_tensors="pt").to(device)
+     output = vlm_model.generate(**inputs, max_new_tokens=200)
+     generated_ids_trimmed = [
+         out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, output)
+     ]
+     response_str = vlm_processor.decode(generated_ids_trimmed[0], skip_special_tokens=True)
+
+     return response_str
+
+ def run_qwen2_vl_inference(vlm_processor, vlm_model, messages, image, device="cuda"):
+     text = vlm_processor.apply_chat_template(
+         messages, tokenize=False, add_generation_prompt=True
+     )
+     image_inputs, video_inputs = process_vision_info(messages)
+     inputs = vlm_processor(
+         text=[text],
+         images=image_inputs,
+         videos=video_inputs,
+         padding=True,
+         return_tensors="pt",
+     )
+     inputs = inputs.to(device)
+     generated_ids = vlm_model.generate(**inputs, max_new_tokens=128)
+     generated_ids_trimmed = [
+         out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+     ]
+     response_str = vlm_processor.decode(generated_ids_trimmed[0], skip_special_tokens=True)
+     return response_str
+
+
+ ### response editing type
+ def vlm_response_editing_type(vlm_processor,
+                               vlm_model,
+                               image,
+                               editing_prompt,
+                               device):
+
+     if isinstance(vlm_model, OpenAI):
+         messages = create_editing_category_messages_gpt4o(editing_prompt)
+         response_str = run_gpt4o_vl_inference(vlm_model, messages)
+     elif isinstance(vlm_model, LlavaNextForConditionalGeneration):
+         messages = create_editing_category_messages_llava(editing_prompt)
+         response_str = run_llava_next_inference(vlm_processor, vlm_model, messages, image, device=device)
+     elif isinstance(vlm_model, Qwen2VLForConditionalGeneration):
+         messages = create_editing_category_messages_qwen2(editing_prompt)
+         response_str = run_qwen2_vl_inference(vlm_processor, vlm_model, messages, image, device=device)
+
+     try:
+         for category_name in ["Addition","Remove","Local","Global","Background"]:
+             if category_name.lower() in response_str.lower():
+                 return category_name
+     except Exception as e:
+         raise gr.Error("Please input OpenAI API Key. Or please input correct commands, including add, delete, and modify commands. If it still does not work, please switch to a more powerful VLM.")
+
+
+ ### response object to be edited
+ def vlm_response_object_wait_for_edit(vlm_processor,
+                                       vlm_model,
+                                       image,
+                                       category,
+                                       editing_prompt,
+                                       device):
+     if category in ["Background", "Global", "Addition"]:
+         edit_object = "nan"
+         return edit_object
+
+     if isinstance(vlm_model, OpenAI):
+         messages = create_ori_object_messages_gpt4o(editing_prompt)
+         response_str = run_gpt4o_vl_inference(vlm_model, messages)
+     elif isinstance(vlm_model, LlavaNextForConditionalGeneration):
+         messages = create_ori_object_messages_llava(editing_prompt)
+         response_str = run_llava_next_inference(vlm_processor, vlm_model, messages, image, device)
+     elif isinstance(vlm_model, Qwen2VLForConditionalGeneration):
+         messages = create_ori_object_messages_qwen2(editing_prompt)
+         response_str = run_qwen2_vl_inference(vlm_processor, vlm_model, messages, image, device)
+     return response_str
+
+
+ ### response mask
+ def vlm_response_mask(vlm_processor,
+                       vlm_model,
+                       category,
+                       image,
+                       editing_prompt,
+                       object_wait_for_edit,
+                       sam=None,
+                       sam_predictor=None,
+                       sam_automask_generator=None,
+                       groundingdino_model=None,
+                       device=None,
+                       ):
+     mask = None
+     if editing_prompt is None or len(editing_prompt)==0:
+         raise gr.Error("Please input the editing instruction!")
+     height, width = image.shape[:2]
+     if category=="Addition":
+         try:
+             if isinstance(vlm_model, OpenAI):
+                 base64_image = encode_image(image)
+                 messages = create_add_object_messages_gpt4o(editing_prompt, base64_image, height=height, width=width)
+                 response_str = run_gpt4o_vl_inference(vlm_model, messages)
+             elif isinstance(vlm_model, LlavaNextForConditionalGeneration):
+                 messages = create_add_object_messages_llava(editing_prompt, height=height, width=width)
+                 response_str = run_llava_next_inference(vlm_processor, vlm_model, messages, image, device)
+             elif isinstance(vlm_model, Qwen2VLForConditionalGeneration):
+                 base64_image = encode_image(image)
+                 messages = create_add_object_messages_qwen2(editing_prompt, base64_image, height=height, width=width)
+                 response_str = run_qwen2_vl_inference(vlm_processor, vlm_model, messages, image, device)
+             pattern = r'\[\d{1,3}(?:,\s*\d{1,3}){3}\]'
+             box = re.findall(pattern, response_str)
+             box = box[0][1:-1].split(",")
+             for i in range(len(box)):
+                 box[i] = int(box[i])
+             cus_mask = np.zeros((height, width))
+             cus_mask[box[1]: box[1]+box[3], box[0]: box[0]+box[2]]=255
+             mask = cus_mask
+         except:
+             raise gr.Error("Please set the mask manually, currently the VLM cannot output the mask!")
+
+     elif category=="Background":
+         labels = "background"
+     elif category=="Global":
+         mask = 255 * np.zeros((height, width))
+     else:
+         labels = object_wait_for_edit
+
+     if mask is None:
+         for thresh in [0.3,0.25,0.2,0.15,0.1,0.05,0]:
+             try:
+                 detections = run_grounded_sam(
+                     input_image={"image":Image.fromarray(image.astype('uint8')),
+                                  "mask":None},
+                     text_prompt=labels,
+                     task_type="seg",
+                     box_threshold=thresh,
+                     text_threshold=0.25,
+                     iou_threshold=0.5,
+                     scribble_mode="split",
+                     sam=sam,
+                     sam_predictor=sam_predictor,
+                     sam_automask_generator=sam_automask_generator,
+                     groundingdino_model=groundingdino_model,
+                     device=device,
+                 )
+                 mask = np.array(detections[0,0,...].cpu()) * 255
+                 break
+             except:
+                 print(f"wrong in threshold: {thresh}, continue")
+                 continue
+     return mask
+
+
+ def vlm_response_prompt_after_apply_instruction(vlm_processor,
+                                                 vlm_model,
+                                                 image,
+                                                 editing_prompt,
+                                                 device):
+
+     try:
+         if isinstance(vlm_model, OpenAI):
+             base64_image = encode_image(image)
+             messages = create_apply_editing_messages_gpt4o(editing_prompt, base64_image)
+             response_str = run_gpt4o_vl_inference(vlm_model, messages)
+         elif isinstance(vlm_model, LlavaNextForConditionalGeneration):
+             messages = create_apply_editing_messages_llava(editing_prompt)
+             response_str = run_llava_next_inference(vlm_processor, vlm_model, messages, image, device)
+         elif isinstance(vlm_model, Qwen2VLForConditionalGeneration):
+             base64_image = encode_image(image)
+             messages = create_apply_editing_messages_qwen2(editing_prompt, base64_image)
+             response_str = run_qwen2_vl_inference(vlm_processor, vlm_model, messages, image, device)
+         else:
+             raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
+     except Exception as e:
+         raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
+     return response_str
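A quick, self-contained round-trip check of the encode_image helper shared by both pipeline files above; a random RGB array stands in for the real input image.

import base64
from io import BytesIO

import numpy as np
from PIL import Image

def encode_image(img):
    img = Image.fromarray(img.astype('uint8'))
    buffered = BytesIO()
    img.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode('utf-8')

dummy = (np.random.rand(64, 64, 3) * 255).astype('uint8')    # stand-in for the edited image
b64 = encode_image(dummy)
decoded = Image.open(BytesIO(base64.b64decode(b64)))
assert np.array_equal(np.array(decoded), dummy)              # PNG encoding is lossless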
vlm_template.py ADDED
@@ -0,0 +1,120 @@
+ import os
+ import sys
+ import torch
+ from openai import OpenAI
+ from transformers import (
+     LlavaNextProcessor, LlavaNextForConditionalGeneration,
+     Qwen2VLForConditionalGeneration, Qwen2VLProcessor
+ )
+ ## init device
+ device = "cpu"
+ torch_dtype = torch.float16
+
+
+ vlms_list = [
+     # {
+     #     "type": "llava-next",
+     #     "name": "llava-v1.6-mistral-7b-hf",
+     #     "local_path": "models/vlms/llava-v1.6-mistral-7b-hf",
+     #     "processor": LlavaNextProcessor.from_pretrained(
+     #         "models/vlms/llava-v1.6-mistral-7b-hf"
+     #     ) if os.path.exists("models/vlms/llava-v1.6-mistral-7b-hf") else LlavaNextProcessor.from_pretrained(
+     #         "llava-hf/llava-v1.6-mistral-7b-hf"
+     #     ),
+     #     "model": LlavaNextForConditionalGeneration.from_pretrained(
+     #         "models/vlms/llava-v1.6-mistral-7b-hf", torch_dtype=torch_dtype, device_map=device
+     #     ).to("cpu") if os.path.exists("models/vlms/llava-v1.6-mistral-7b-hf") else
+     #     LlavaNextForConditionalGeneration.from_pretrained(
+     #         "llava-hf/llava-v1.6-mistral-7b-hf", torch_dtype=torch_dtype, device_map=device
+     #     ).to("cpu"),
+     # },
+     {
+         "type": "llava-next",
+         "name": "llama3-llava-next-8b-hf (Preload)",
+         "local_path": "models/vlms/llama3-llava-next-8b-hf",
+         "processor": LlavaNextProcessor.from_pretrained(
+             "models/vlms/llama3-llava-next-8b-hf"
+         ) if os.path.exists("models/vlms/llama3-llava-next-8b-hf") else LlavaNextProcessor.from_pretrained(
+             "llava-hf/llama3-llava-next-8b-hf"
+         ),
+         "model": LlavaNextForConditionalGeneration.from_pretrained(
+             "models/vlms/llama3-llava-next-8b-hf", torch_dtype=torch_dtype, device_map=device
+         ).to("cpu") if os.path.exists("models/vlms/llama3-llava-next-8b-hf") else
+         LlavaNextForConditionalGeneration.from_pretrained(
+             "llava-hf/llama3-llava-next-8b-hf", torch_dtype=torch_dtype, device_map=device
+         ).to("cpu"),
+     },
+     # {
+     #     "type": "llava-next",
+     #     "name": "llava-v1.6-vicuna-13b-hf",
+     #     "local_path": "models/vlms/llava-v1.6-vicuna-13b-hf",
+     #     "processor": LlavaNextProcessor.from_pretrained(
+     #         "models/vlms/llava-v1.6-vicuna-13b-hf"
+     #     ) if os.path.exists("models/vlms/llava-v1.6-vicuna-13b-hf") else LlavaNextProcessor.from_pretrained(
+     #         "llava-hf/llava-v1.6-vicuna-13b-hf"
+     #     ),
+     #     "model": LlavaNextForConditionalGeneration.from_pretrained(
+     #         "models/vlms/llava-v1.6-vicuna-13b-hf", torch_dtype=torch_dtype, device_map=device
+     #     ).to("cpu") if os.path.exists("models/vlms/llava-v1.6-vicuna-13b-hf") else
+     #     LlavaNextForConditionalGeneration.from_pretrained(
+     #         "llava-hf/llava-v1.6-vicuna-13b-hf", torch_dtype=torch_dtype, device_map=device
+     #     ).to("cpu"),
+     # },
+     # {
+     #     "type": "llava-next",
+     #     "name": "llava-v1.6-34b-hf",
+     #     "local_path": "models/vlms/llava-v1.6-34b-hf",
+     #     "processor": LlavaNextProcessor.from_pretrained(
+     #         "models/vlms/llava-v1.6-34b-hf"
+     #     ) if os.path.exists("models/vlms/llava-v1.6-34b-hf") else LlavaNextProcessor.from_pretrained(
+     #         "llava-hf/llava-v1.6-34b-hf"
+     #     ),
+     #     "model": LlavaNextForConditionalGeneration.from_pretrained(
+     #         "models/vlms/llava-v1.6-34b-hf", torch_dtype=torch_dtype, device_map=device
+     #     ).to("cpu") if os.path.exists("models/vlms/llava-v1.6-34b-hf") else
+     #     LlavaNextForConditionalGeneration.from_pretrained(
+     #         "llava-hf/llava-v1.6-34b-hf", torch_dtype=torch_dtype, device_map=device
+     #     ).to("cpu"),
+     # },
+     # {
+     #     "type": "qwen2-vl",
+     #     "name": "Qwen2-VL-2B-Instruct",
+     #     "local_path": "models/vlms/Qwen2-VL-2B-Instruct",
+     #     "processor": Qwen2VLProcessor.from_pretrained(
+     #         "models/vlms/Qwen2-VL-2B-Instruct"
+     #     ) if os.path.exists("models/vlms/Qwen2-VL-2B-Instruct") else Qwen2VLProcessor.from_pretrained(
+     #         "Qwen/Qwen2-VL-2B-Instruct"
+     #     ),
+     #     "model": Qwen2VLForConditionalGeneration.from_pretrained(
+     #         "models/vlms/Qwen2-VL-2B-Instruct", torch_dtype=torch_dtype, device_map=device
+     #     ).to("cpu") if os.path.exists("models/vlms/Qwen2-VL-2B-Instruct") else
+     #     Qwen2VLForConditionalGeneration.from_pretrained(
+     #         "Qwen/Qwen2-VL-2B-Instruct", torch_dtype=torch_dtype, device_map=device
+     #     ).to("cpu"),
+     # },
+     {
+         "type": "qwen2-vl",
+         "name": "Qwen2-VL-7B-Instruct (Default)",
+         "local_path": "models/vlms/Qwen2-VL-7B-Instruct",
+         "processor": Qwen2VLProcessor.from_pretrained(
+             "models/vlms/Qwen2-VL-7B-Instruct"
+         ) if os.path.exists("models/vlms/Qwen2-VL-7B-Instruct") else Qwen2VLProcessor.from_pretrained(
+             "Qwen/Qwen2-VL-7B-Instruct"
+         ),
+         "model": Qwen2VLForConditionalGeneration.from_pretrained(
+             "models/vlms/Qwen2-VL-7B-Instruct", torch_dtype=torch_dtype, device_map=device
+         ).to("cpu") if os.path.exists("models/vlms/Qwen2-VL-7B-Instruct") else
+         Qwen2VLForConditionalGeneration.from_pretrained(
+             "Qwen/Qwen2-VL-7B-Instruct", torch_dtype=torch_dtype, device_map=device
+         ).to("cpu"),
+     },
+     {
+         "type": "openai",
+         "name": "GPT4-o (Highly Recommended)",
+         "local_path": "",
+         "processor": "",
+         "model": ""
+     },
+ ]
+
+ vlms_template = {k["name"]: (k["type"], k["local_path"], k["processor"], k["model"]) for k in vlms_list}
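A small sketch of how the vlms_template mapping built above is meant to be consumed: look up an entry by display name and unpack its (type, local_path, processor, model) tuple. A dummy entry with processor and model set to None is used here so the example runs without downloading any weights.

mock_vlms_list = [
    {"type": "qwen2-vl", "name": "Qwen2-VL-7B-Instruct (Default)",
     "local_path": "models/vlms/Qwen2-VL-7B-Instruct", "processor": None, "model": None},
]
mock_template = {k["name"]: (k["type"], k["local_path"], k["processor"], k["model"])
                 for k in mock_vlms_list}

vlm_type, local_path, processor, model = mock_template["Qwen2-VL-7B-Instruct (Default)"]
print(vlm_type, local_path)   # qwen2-vl models/vlms/Qwen2-VL-7B-Instruct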