Spaces (status: Runtime error)
Upload folder using huggingface_hub
Browse files
- AnchorIT_CIR_app.py +0 -0
- README.md +3 -9
- __pycache__/aspect_ratio_template.cpython-310.pyc +0 -0
- __pycache__/base_model_template.cpython-310.pyc +0 -0
- __pycache__/brushedit_all_in_one_pipeline.cpython-310.pyc +0 -0
- __pycache__/llm_pipeline.cpython-310.pyc +0 -0
- __pycache__/llm_template.cpython-310.pyc +0 -0
- __pycache__/vlm_pipeline.cpython-310.pyc +0 -0
- __pycache__/vlm_pipeline_noqwen.cpython-310.pyc +0 -0
- __pycache__/vlm_template.cpython-310.pyc +0 -0
- aspect_ratio_template.py +88 -0
- base_model_template.py +61 -0
- brushedit_all_in_one_pipeline.py +73 -0
- brushedit_app.py +1705 -0
- brushedit_app_315_0.py +1696 -0
- brushedit_app_315_1.py +1624 -0
- brushedit_app_315_2.py +1627 -0
- brushedit_app_gradio_new.py +0 -0
- brushedit_app_new.py +0 -0
- brushedit_app_new_0404_cirr_blip1.py +2058 -0
- brushedit_app_new_aftermeeting_nocirr.py +1809 -0
- brushedit_app_new_doable.py +1860 -0
- brushedit_app_new_jietu.py +0 -0
- brushedit_app_new_jietu2.py +0 -0
- brushedit_app_new_notqwen.py +0 -0
- brushedit_app_old.py +1702 -0
- brushedit_app_only_integrate.py +1725 -0
- brushedit_app_without_clip.py +1758 -0
- llm_pipeline.py +35 -0
- llm_template.py +21 -0
- vlm_pipeline.py +266 -0
- vlm_pipeline_noqwen.py +263 -0
- vlm_pipeline_old.py +228 -0
- vlm_template.py +120 -0
AnchorIT_CIR_app.py
ADDED
The diff for this file is too large to render.
See raw diff
README.md
CHANGED
@@ -1,12 +1,6 @@
 ---
-title:
-
-colorFrom: blue
-colorTo: purple
+title: AnchorIT_ZS-CIR_BNU
+app_file: AnchorIT_CIR_app.py
 sdk: gradio
-sdk_version:
-app_file: app.py
-pinned: false
+sdk_version: 4.44.1
 ---
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
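After this change the Space metadata points Hugging Face at the CIR demo entry point. Assembled from the added lines above, the resulting README.md front matter reads:

---
title: AnchorIT_ZS-CIR_BNU
app_file: AnchorIT_CIR_app.py
sdk: gradio
sdk_version: 4.44.1
---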
__pycache__/aspect_ratio_template.cpython-310.pyc
ADDED
Binary file (1.16 kB)

__pycache__/base_model_template.cpython-310.pyc
ADDED
Binary file (1.23 kB)

__pycache__/brushedit_all_in_one_pipeline.cpython-310.pyc
ADDED
Binary file (1.73 kB)

__pycache__/llm_pipeline.cpython-310.pyc
ADDED
Binary file (1.77 kB)

__pycache__/llm_template.cpython-310.pyc
ADDED
Binary file (562 Bytes)

__pycache__/vlm_pipeline.cpython-310.pyc
ADDED
Binary file (6.85 kB)

__pycache__/vlm_pipeline_noqwen.cpython-310.pyc
ADDED
Binary file (6.83 kB)

__pycache__/vlm_template.cpython-310.pyc
ADDED
Binary file (1.3 kB)
aspect_ratio_template.py
ADDED
@@ -0,0 +1,88 @@
# From https://github.com/TencentARC/PhotoMaker/pull/120 written by https://github.com/DiscoNova
# Note: Since output width & height need to be divisible by 8, the w & h -values do
# not exactly match the stated aspect ratios... but they are "close enough":)

aspect_ratio_list = [
    {
        "name": "Small Square (1:1)",
        "w": 640,
        "h": 640,
    },
    {
        "name": "Custom resolution",
        "w": "",
        "h": "",
    },
    {
        "name": "Instagram (1:1)",
        "w": 1024,
        "h": 1024,
    },
    {
        "name": "35mm film / Landscape (3:2)",
        "w": 1024,
        "h": 680,
    },
    {
        "name": "35mm film / Portrait (2:3)",
        "w": 680,
        "h": 1024,
    },
    {
        "name": "CRT Monitor / Landscape (4:3)",
        "w": 1024,
        "h": 768,
    },
    {
        "name": "CRT Monitor / Portrait (3:4)",
        "w": 768,
        "h": 1024,
    },
    {
        "name": "Widescreen TV / Landscape (16:9)",
        "w": 1024,
        "h": 576,
    },
    {
        "name": "Widescreen TV / Portrait (9:16)",
        "w": 576,
        "h": 1024,
    },
    {
        "name": "Widescreen Monitor / Landscape (16:10)",
        "w": 1024,
        "h": 640,
    },
    {
        "name": "Widescreen Monitor / Portrait (10:16)",
        "w": 640,
        "h": 1024,
    },
    {
        "name": "Cinemascope (2.39:1)",
        "w": 1024,
        "h": 424,
    },
    {
        "name": "Widescreen Movie (1.85:1)",
        "w": 1024,
        "h": 552,
    },
    {
        "name": "Academy Movie (1.37:1)",
        "w": 1024,
        "h": 744,
    },
    {
        "name": "Sheet-print (A-series) / Landscape (297:210)",
        "w": 1024,
        "h": 720,
    },
    {
        "name": "Sheet-print (A-series) / Portrait (210:297)",
        "w": 720,
        "h": 1024,
    },
]

aspect_ratios = {k["name"]: (k["w"], k["h"]) for k in aspect_ratio_list}
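The module only exposes the `aspect_ratios` lookup, which maps each preset name to a `(w, h)` pair; the "Custom resolution" entry stores empty strings so callers fall back to the source image size. A minimal resolution sketch under that assumption (the `resolve_aspect_ratio` helper and the 720x540 example values are illustrative, not part of the commit):

from aspect_ratio_template import aspect_ratios

def resolve_aspect_ratio(name, image_w, image_h):
    # "Custom resolution" stores empty strings, so fall back to the input size.
    w, h = aspect_ratios[name]
    if w == "" or h == "":
        return image_w, image_h
    return w, h

print(resolve_aspect_ratio("Widescreen TV / Landscape (16:9)", 720, 540))  # (1024, 576)
print(resolve_aspect_ratio("Custom resolution", 720, 540))                 # (720, 540)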
base_model_template.py
ADDED
@@ -0,0 +1,61 @@
import os
import torch
from huggingface_hub import snapshot_download

from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler



torch_dtype = torch.float16
device = "cpu"

BrushEdit_path = "models/"
if not os.path.exists(BrushEdit_path):
    BrushEdit_path = snapshot_download(
        repo_id="TencentARC/BrushEdit",
        local_dir=BrushEdit_path,
        token=os.getenv("HF_TOKEN"),
    )
brushnet_path = os.path.join(BrushEdit_path, "brushnetX")
brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch_dtype)


base_models_list = [
    # {
    #     "name": "dreamshaper_8 (Preload)",
    #     "local_path": "models/base_model/dreamshaper_8",
    #     "pipe": StableDiffusionBrushNetPipeline.from_pretrained(
    #         "models/base_model/dreamshaper_8", brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
    #     ).to(device)
    # },
    # {
    #     "name": "epicrealism (Preload)",
    #     "local_path": "models/base_model/epicrealism_naturalSinRC1VAE",
    #     "pipe": StableDiffusionBrushNetPipeline.from_pretrained(
    #         "models/base_model/epicrealism_naturalSinRC1VAE", brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
    #     ).to(device)
    # },
    {
        "name": "henmixReal (Preload)",
        "local_path": "models/base_model/henmixReal_v5c",
        "pipe": StableDiffusionBrushNetPipeline.from_pretrained(
            "models/base_model/henmixReal_v5c", brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
        ).to(device)
    },
    {
        "name": "meinamix (Preload)",
        "local_path": "models/base_model/meinamix_meinaV11",
        "pipe": StableDiffusionBrushNetPipeline.from_pretrained(
            "models/base_model/meinamix_meinaV11", brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
        ).to(device)
    },
    {
        "name": "realisticVision (Default)",
        "local_path": "models/base_model/realisticVisionV60B1_v51VAE",
        "pipe": StableDiffusionBrushNetPipeline.from_pretrained(
            "models/base_model/realisticVisionV60B1_v51VAE", brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
        ).to(device)
    },
]

base_models_template = {k["name"]: (k["local_path"], k["pipe"]) for k in base_models_list}
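`base_models_template` maps each display name to a `(local_path, pipe)` pair, which is how the app's `update_base_model` switches checkpoints at runtime. A minimal lookup sketch under that assumption (the device string is illustrative):

from base_model_template import base_models_template

# Preloaded entries carry a ready pipeline; commented-out entries would return "".
base_model_path, pipe = base_models_template["realisticVision (Default)"]
if pipe != "":
    pipe.to("cuda")  # move the preloaded pipeline to the target device
else:
    print(f"Pipeline not preloaded; load it from {base_model_path} first")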
brushedit_all_in_one_pipeline.py
ADDED
@@ -0,0 +1,73 @@
from PIL import Image, ImageEnhance
from diffusers.image_processor import VaeImageProcessor

import numpy as np
import cv2



def BrushEdit_Pipeline(pipe,
                       prompts,
                       mask_np,
                       original_image,
                       generator,
                       num_inference_steps,
                       guidance_scale,
                       control_strength,
                       negative_prompt,
                       num_samples,
                       blending):
    if mask_np.ndim != 3:
        mask_np = mask_np[:, :, np.newaxis]

    mask_np = mask_np / 255
    height, width = mask_np.shape[0], mask_np.shape[1]
    ## resize the mask and original image to the same size which is divisible by vae_scale_factor
    image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
    height_new, width_new = image_processor.get_default_height_width(original_image, height, width)
    mask_np = cv2.resize(mask_np, (width_new, height_new))[:,:,np.newaxis]
    mask_blurred = cv2.GaussianBlur(mask_np*255, (21, 21), 0)/255
    mask_blurred = mask_blurred[:, :, np.newaxis]

    original_image = cv2.resize(original_image, (width_new, height_new))

    init_image = original_image * (1 - mask_np)
    init_image = Image.fromarray(init_image.astype(np.uint8)).convert("RGB")
    mask_image = Image.fromarray((mask_np.repeat(3, -1) * 255).astype(np.uint8)).convert("RGB")

    brushnet_conditioning_scale = float(control_strength)

    images = pipe(
        [prompts] * num_samples,
        init_image,
        mask_image,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator,
        brushnet_conditioning_scale=brushnet_conditioning_scale,
        negative_prompt=[negative_prompt]*num_samples,
        height=height_new,
        width=width_new,
    ).images
    ## convert to vae shape format, must be divisible by 8
    original_image_pil = Image.fromarray(original_image).convert("RGB")
    init_image_np = np.array(image_processor.preprocess(original_image_pil, height=height_new, width=width_new).squeeze())
    init_image_np = ((init_image_np.transpose(1,2,0) + 1.) / 2.) * 255
    init_image_np = init_image_np.astype(np.uint8)
    if blending:
        mask_blurred = mask_blurred * 0.5 + 0.5
        image_all = []
        for image_i in images:
            image_np = np.array(image_i)
            ## blending
            image_pasted = init_image_np * (1 - mask_blurred) + mask_blurred * image_np
            image_pasted = image_pasted.astype(np.uint8)
            image = Image.fromarray(image_pasted)
            image_all.append(image)
    else:
        image_all = images


    return image_all, mask_image, mask_np, init_image_np
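`BrushEdit_Pipeline` expects a NumPy RGB image, a single-channel uint8 mask in [0, 255], and a seeded generator, and returns the edited samples plus the blended intermediates. A minimal call sketch, mirroring how `process()` in brushedit_app.py invokes it; the file names, mask rectangle, and parameter values are placeholders, and the checkpoint paths are assumed to exist as laid out in base_model_template.py:

import numpy as np
import torch
from PIL import Image
from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel
from brushedit_all_in_one_pipeline import BrushEdit_Pipeline

# Assemble the pipeline the same way base_model_template.py / brushedit_app.py do.
brushnet = BrushNetModel.from_pretrained("models/brushnetX", torch_dtype=torch.float16)
pipe = StableDiffusionBrushNetPipeline.from_pretrained(
    "models/base_model/realisticVisionV60B1_v51VAE", brushnet=brushnet,
    torch_dtype=torch.float16, low_cpu_mem_usage=False).to("cuda")

original_image = np.array(Image.open("input.png").convert("RGB"))
mask = np.zeros(original_image.shape[:2], dtype=np.uint8)
mask[100:300, 150:400] = 255  # region to edit, values in [0, 255]

generator = torch.Generator("cuda").manual_seed(648464818)
images, mask_image, mask_np, init_image_np = BrushEdit_Pipeline(
    pipe,
    "a delicate floral skirt",        # target prompt for the masked region
    mask, original_image, generator,
    num_inference_steps=50, guidance_scale=7.5, control_strength=1.0,
    negative_prompt="ugly, low quality", num_samples=4, blending=True)
images[0].save("edited.png")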
brushedit_app.py
ADDED
@@ -0,0 +1,1705 @@
##!/usr/bin/python3
# -*- coding: utf-8 -*-
import os, random, sys
import numpy as np
import requests
import torch


import gradio as gr

from PIL import Image


from huggingface_hub import hf_hub_download, snapshot_download
from scipy.ndimage import binary_dilation, binary_erosion
from transformers import (LlavaNextProcessor, LlavaNextForConditionalGeneration,
                          Qwen2VLForConditionalGeneration, Qwen2VLProcessor)

from segment_anything import SamPredictor, build_sam, SamAutomaticMaskGenerator
from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler
from diffusers.image_processor import VaeImageProcessor


from app.src.vlm_pipeline import (
    vlm_response_editing_type,
    vlm_response_object_wait_for_edit,
    vlm_response_mask,
    vlm_response_prompt_after_apply_instruction
)
from app.src.brushedit_all_in_one_pipeline import BrushEdit_Pipeline
from app.utils.utils import load_grounding_dino_model

from app.src.vlm_template import vlms_template
from app.src.base_model_template import base_models_template
from app.src.aspect_ratio_template import aspect_ratios

from openai import OpenAI
# base_openai_url = "https://api.deepseek.com/"

#### Description ####
logo = r"""
<center><img src='./assets/logo_brushedit.png' alt='BrushEdit logo' style="width:80px; margin-bottom:10px"></center>
"""
head = r"""
<div style="text-align: center;">
<h1> BrushEdit: All-In-One Image Inpainting and Editing</h1>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a href='https://liyaowei-stu.github.io/project/BrushEdit/'><img src='https://img.shields.io/badge/Project_Page-BrushEdit-green' alt='Project Page'></a>
<a href='https://arxiv.org/abs/2412.10316'><img src='https://img.shields.io/badge/Paper-Arxiv-blue'></a>
<a href='https://github.com/TencentARC/BrushEdit'><img src='https://img.shields.io/badge/Code-Github-orange'></a>

</div>
</br>
</div>
"""
descriptions = r"""
Official Gradio Demo for <a href='https://tencentarc.github.io/BrushNet/'><b>BrushEdit: All-In-One Image Inpainting and Editing</b></a><br>
🧙 BrushEdit enables precise, user-friendly instruction-based image editing via a inpainting model.<br>
"""

instructions = r"""
Currently, we support two modes: <b>fully automated command editing</b> and <b>interactive command editing</b>.

🛠️ <b>Fully automated instruction-based editing</b>:
<ul>
<li> ⭐️ <b>1.Choose Image: </b> Upload <img src="https://github.com/user-attachments/assets/f2dca1e6-31f9-4716-ae84-907f24415bac" alt="upload" style="display:inline; height:1em; vertical-align:middle;"> or select <img src="https://github.com/user-attachments/assets/de808f7d-c74a-44c7-9cbf-f0dbfc2c1abf" alt="example" style="display:inline; height:1em; vertical-align:middle;"> one image from Example. </li>
<li> ⭐️ <b>2.Input ⌨️ Instructions: </b> Input the instructions (supports addition, deletion, and modification), e.g. remove xxx .</li>
<li> ⭐️ <b>3.Run: </b> Click <b>💫 Run</b> button to automatic edit image.</li>
</ul>

🛠️ <b>Interactive instruction-based editing</b>:
<ul>
<li> ⭐️ <b>1.Choose Image: </b> Upload <img src="https://github.com/user-attachments/assets/f2dca1e6-31f9-4716-ae84-907f24415bac" alt="upload" style="display:inline; height:1em; vertical-align:middle;"> or select <img src="https://github.com/user-attachments/assets/de808f7d-c74a-44c7-9cbf-f0dbfc2c1abf" alt="example" style="display:inline; height:1em; vertical-align:middle;"> one image from Example. </li>
<li> ⭐️ <b>2.Finely Brushing: </b> Use a brush <img src="https://github.com/user-attachments/assets/c466c5cc-ac8f-4b4a-9bc5-04c4737fe1ef" alt="brush" style="display:inline; height:1em; vertical-align:middle;"> to outline the area you want to edit. And You can also use the eraser <img src="https://github.com/user-attachments/assets/b6370369-b080-4550-b0d0-830ff22d9068" alt="eraser" style="display:inline; height:1em; vertical-align:middle;"> to restore. </li>
<li> ⭐️ <b>3.Input ⌨️ Instructions: </b> Input the instructions. </li>
<li> ⭐️ <b>4.Run: </b> Click <b>💫 Run</b> button to automatic edit image. </li>
</ul>

<b> We strongly recommend using GPT-4o for reasoning. </b> After selecting the VLM model as gpt4-o, enter the API KEY and click the Submit and Verify button. If the output is success, you can use gpt4-o normally. Secondarily, we recommend using the Qwen2VL model.

<b> We recommend zooming out in your browser for a better viewing range and experience. </b>

<b> For more detailed feature descriptions, see the bottom. </b>

☕️ Have fun! 🎄 Wishing you a merry Christmas!
"""

tips = r"""
💡 <b>Some Tips</b>:
<ul>
<li> 🤠 After input the instructions, you can click the <b>Generate Mask</b> button. The mask generated by VLM will be displayed in the preview panel on the right side. </li>
<li> 🤠 After generating the mask or when you use the brush to draw the mask, you can perform operations such as <b>randomization</b>, <b>dilation</b>, <b>erosion</b>, and <b>movement</b>. </li>
<li> 🤠 After input the instructions, you can click the <b>Generate Target Prompt</b> button. The target prompt will be displayed in the text box, and you can modify it according to your ideas. </li>
</ul>

💡 <b>Detailed Features</b>:
<ul>
<li> 🎨 <b>Aspect Ratio</b>: Select the aspect ratio of the image. To prevent OOM, 1024px is the maximum resolution.</li>
<li> 🎨 <b>VLM Model</b>: Select the VLM model. We use preloaded models to save time. To use other VLM models, download them and uncomment the relevant lines in vlm_template.py from our GitHub repo. </li>
<li> 🎨 <b>Generate Mask</b>: According to the input instructions, generate a mask for the area that may need to be edited. </li>
<li> 🎨 <b>Square/Circle Mask</b>: Based on the existing mask, generate masks for squares and circles. (The coarse-grained mask provides more editing imagination.) </li>
<li> 🎨 <b>Invert Mask</b>: Invert the mask to generate a new mask. </li>
<li> 🎨 <b>Dilation/Erosion Mask</b>: Expand or shrink the mask to include or exclude more areas. </li>
<li> 🎨 <b>Move Mask</b>: Move the mask to a new position. </li>
<li> 🎨 <b>Generate Target Prompt</b>: Generate a target prompt based on the input instructions. </li>
<li> 🎨 <b>Target Prompt</b>: Description for masking area, manual input or modification can be made when the content generated by VLM does not meet expectations. </li>
<li> 🎨 <b>Blending</b>: Blending brushnet's output and the original input, ensuring the original image details in the unedited areas. (turn off is beeter when removing.) </li>
<li> 🎨 <b>Control length</b>: The intensity of editing and inpainting. </li>
</ul>

💡 <b>Advanced Features</b>:
<ul>
<li> 🎨 <b>Base Model</b>: We use preloaded models to save time. To use other VLM models, download them and uncomment the relevant lines in vlm_template.py from our GitHub repo. </li>
<li> 🎨 <b>Blending</b>: Blending brushnet's output and the original input, ensuring the original image details in the unedited areas. (turn off is beeter when removing.) </li>
<li> 🎨 <b>Control length</b>: The intensity of editing and inpainting. </li>
<li> 🎨 <b>Num samples</b>: The number of samples to generate. </li>
<li> 🎨 <b>Negative prompt</b>: The negative prompt for the classifier-free guidance. </li>
<li> 🎨 <b>Guidance scale</b>: The guidance scale for the classifier-free guidance. </li>
</ul>


"""



citation = r"""
If BrushEdit is helpful, please help to ⭐ the <a href='https://github.com/TencentARC/BrushEdit' target='_blank'>Github Repo</a>. Thanks!
[](https://github.com/TencentARC/BrushEdit)
---
📝 **Citation**
<br>
If our work is useful for your research, please consider citing:
```bibtex
@misc{li2024brushedit,
      title={BrushEdit: All-In-One Image Inpainting and Editing},
      author={Yaowei Li and Yuxuan Bian and Xuan Ju and Zhaoyang Zhang and and Junhao Zhuang and Ying Shan and Yuexian Zou and Qiang Xu},
      year={2024},
      eprint={2412.10316},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
```
📧 **Contact**
<br>
If you have any questions, please feel free to reach me out at <b>liyaowei@gmail.com</b>.
"""

# - - - - - examples - - - - - #
EXAMPLES = [

    [
        Image.open("./assets/frog/frog.jpeg").convert("RGBA"),
        "add a magic hat on frog head.",
        642087011,
        "frog",
        "frog",
        True,
        False,
        "GPT4-o (Highly Recommended)"
    ],
    [
        Image.open("./assets/chinese_girl/chinese_girl.png").convert("RGBA"),
        "replace the background to ancient China.",
        648464818,
        "chinese_girl",
        "chinese_girl",
        True,
        False,
        "GPT4-o (Highly Recommended)"
    ],
    [
        Image.open("./assets/angel_christmas/angel_christmas.png").convert("RGBA"),
        "remove the deer.",
        648464818,
        "angel_christmas",
        "angel_christmas",
        False,
        False,
        "GPT4-o (Highly Recommended)"
    ],
    [
        Image.open("./assets/sunflower_girl/sunflower_girl.png").convert("RGBA"),
        "add a wreath on head.",
        648464818,
        "sunflower_girl",
        "sunflower_girl",
        True,
        False,
        "GPT4-o (Highly Recommended)"
    ],
    [
        Image.open("./assets/girl_on_sun/girl_on_sun.png").convert("RGBA"),
        "add a butterfly fairy.",
        648464818,
        "girl_on_sun",
        "girl_on_sun",
        True,
        False,
        "GPT4-o (Highly Recommended)"
    ],
    [
        Image.open("./assets/spider_man_rm/spider_man.png").convert("RGBA"),
        "remove the christmas hat.",
        642087011,
        "spider_man_rm",
        "spider_man_rm",
        False,
        False,
        "GPT4-o (Highly Recommended)"
    ],
    [
        Image.open("./assets/anime_flower/anime_flower.png").convert("RGBA"),
        "remove the flower.",
        642087011,
        "anime_flower",
        "anime_flower",
        False,
        False,
        "GPT4-o (Highly Recommended)"
    ],
    [
        Image.open("./assets/chenduling/chengduling.jpg").convert("RGBA"),
        "replace the clothes to a delicated floral skirt.",
        648464818,
        "chenduling",
        "chenduling",
        True,
        False,
        "GPT4-o (Highly Recommended)"
    ],
    [
        Image.open("./assets/hedgehog_rp_bg/hedgehog.png").convert("RGBA"),
        "make the hedgehog in Italy.",
        648464818,
        "hedgehog_rp_bg",
        "hedgehog_rp_bg",
        True,
        False,
        "GPT4-o (Highly Recommended)"
    ],

]

INPUT_IMAGE_PATH = {
    "frog": "./assets/frog/frog.jpeg",
    "chinese_girl": "./assets/chinese_girl/chinese_girl.png",
    "angel_christmas": "./assets/angel_christmas/angel_christmas.png",
    "sunflower_girl": "./assets/sunflower_girl/sunflower_girl.png",
    "girl_on_sun": "./assets/girl_on_sun/girl_on_sun.png",
    "spider_man_rm": "./assets/spider_man_rm/spider_man.png",
    "anime_flower": "./assets/anime_flower/anime_flower.png",
    "chenduling": "./assets/chenduling/chengduling.jpg",
    "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/hedgehog.png",
}
MASK_IMAGE_PATH = {
    "frog": "./assets/frog/mask_f7b350de-6f2c-49e3-b535-995c486d78e7.png",
    "chinese_girl": "./assets/chinese_girl/mask_54759648-0989-48e0-bc82-f20e28b5ec29.png",
    "angel_christmas": "./assets/angel_christmas/mask_f15d9b45-c978-4e3d-9f5f-251e308560c3.png",
    "sunflower_girl": "./assets/sunflower_girl/mask_99cc50b4-7dc4-4de5-8748-ec10772f0317.png",
    "girl_on_sun": "./assets/girl_on_sun/mask_264eac8b-8b65-479c-9755-020a60880c37.png",
    "spider_man_rm": "./assets/spider_man_rm/mask_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png",
    "anime_flower": "./assets/anime_flower/mask_37553172-9b38-4727-bf2e-37d7e2b93461.png",
    "chenduling": "./assets/chenduling/mask_68e3ff6f-da07-4b37-91df-13d6eed7b997.png",
    "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/mask_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png",
}
MASKED_IMAGE_PATH = {
    "frog": "./assets/frog/masked_image_f7b350de-6f2c-49e3-b535-995c486d78e7.png",
    "chinese_girl": "./assets/chinese_girl/masked_image_54759648-0989-48e0-bc82-f20e28b5ec29.png",
    "angel_christmas": "./assets/angel_christmas/masked_image_f15d9b45-c978-4e3d-9f5f-251e308560c3.png",
    "sunflower_girl": "./assets/sunflower_girl/masked_image_99cc50b4-7dc4-4de5-8748-ec10772f0317.png",
    "girl_on_sun": "./assets/girl_on_sun/masked_image_264eac8b-8b65-479c-9755-020a60880c37.png",
    "spider_man_rm": "./assets/spider_man_rm/masked_image_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png",
    "anime_flower": "./assets/anime_flower/masked_image_37553172-9b38-4727-bf2e-37d7e2b93461.png",
    "chenduling": "./assets/chenduling/masked_image_68e3ff6f-da07-4b37-91df-13d6eed7b997.png",
    "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/masked_image_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png",
}
OUTPUT_IMAGE_PATH = {
    "frog": "./assets/frog/image_edit_f7b350de-6f2c-49e3-b535-995c486d78e7_1.png",
    "chinese_girl": "./assets/chinese_girl/image_edit_54759648-0989-48e0-bc82-f20e28b5ec29_1.png",
    "angel_christmas": "./assets/angel_christmas/image_edit_f15d9b45-c978-4e3d-9f5f-251e308560c3_0.png",
    "sunflower_girl": "./assets/sunflower_girl/image_edit_99cc50b4-7dc4-4de5-8748-ec10772f0317_3.png",
    "girl_on_sun": "./assets/girl_on_sun/image_edit_264eac8b-8b65-479c-9755-020a60880c37_0.png",
    "spider_man_rm": "./assets/spider_man_rm/image_edit_a5d410e6-8e8d-432f-8144-defbc3e1eae9_0.png",
    "anime_flower": "./assets/anime_flower/image_edit_37553172-9b38-4727-bf2e-37d7e2b93461_2.png",
    "chenduling": "./assets/chenduling/image_edit_68e3ff6f-da07-4b37-91df-13d6eed7b997_0.png",
    "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/image_edit_db7f8bf8-8349-46d3-b14e-43d67fbe25d3_3.png",
}

# os.environ['GRADIO_TEMP_DIR'] = 'gradio_temp_dir'
# os.makedirs('gradio_temp_dir', exist_ok=True)

VLM_MODEL_NAMES = list(vlms_template.keys())
DEFAULT_VLM_MODEL_NAME = "Qwen2-VL-7B-Instruct (Default)"
BASE_MODELS = list(base_models_template.keys())
DEFAULT_BASE_MODEL = "realisticVision (Default)"

ASPECT_RATIO_LABELS = list(aspect_ratios)
DEFAULT_ASPECT_RATIO = ASPECT_RATIO_LABELS[0]


## init device
try:
    if torch.cuda.is_available():
        device = "cuda"
    elif sys.platform == "darwin" and torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
except:
    device = "cpu"

# ## init torch dtype
# if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
#     torch_dtype = torch.bfloat16
# else:
#     torch_dtype = torch.float16

# if device == "mps":
#     torch_dtype = torch.float16

torch_dtype = torch.float16



# download hf models
BrushEdit_path = "models/"
if not os.path.exists(BrushEdit_path):
    BrushEdit_path = snapshot_download(
        repo_id="TencentARC/BrushEdit",
        local_dir=BrushEdit_path,
        token=os.getenv("HF_TOKEN"),
    )

## init default VLM
vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[DEFAULT_VLM_MODEL_NAME]
if vlm_processor != "" and vlm_model != "":
    vlm_model.to(device)
else:
    raise gr.Error("Please Download default VLM model "+ DEFAULT_VLM_MODEL_NAME +" first.")


## init base model
base_model_path = os.path.join(BrushEdit_path, "base_model/realisticVisionV60B1_v51VAE")
brushnet_path = os.path.join(BrushEdit_path, "brushnetX")
sam_path = os.path.join(BrushEdit_path, "sam/sam_vit_h_4b8939.pth")
groundingdino_path = os.path.join(BrushEdit_path, "grounding_dino/groundingdino_swint_ogc.pth")


# input brushnetX ckpt path
brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch_dtype)
pipe = StableDiffusionBrushNetPipeline.from_pretrained(
    base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
)
# speed up diffusion process with faster scheduler and memory optimization
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
# remove following line if xformers is not installed or when using Torch 2.0.
# pipe.enable_xformers_memory_efficient_attention()
pipe.enable_model_cpu_offload()


## init SAM
sam = build_sam(checkpoint=sam_path)
sam.to(device=device)
sam_predictor = SamPredictor(sam)
sam_automask_generator = SamAutomaticMaskGenerator(sam)

## init groundingdino_model
config_file = 'app/utils/GroundingDINO_SwinT_OGC.py'
groundingdino_model = load_grounding_dino_model(config_file, groundingdino_path, device=device)

## Ordinary function
def crop_and_resize(image: Image.Image,
                    target_width: int,
                    target_height: int) -> Image.Image:
    """
    Crops and resizes an image while preserving the aspect ratio.

    Args:
        image (Image.Image): Input PIL image to be cropped and resized.
        target_width (int): Target width of the output image.
        target_height (int): Target height of the output image.

    Returns:
        Image.Image: Cropped and resized image.
    """
    # Original dimensions
    original_width, original_height = image.size
    original_aspect = original_width / original_height
    target_aspect = target_width / target_height

    # Calculate crop box to maintain aspect ratio
    if original_aspect > target_aspect:
        # Crop horizontally
        new_width = int(original_height * target_aspect)
        new_height = original_height
        left = (original_width - new_width) / 2
        top = 0
        right = left + new_width
        bottom = original_height
    else:
        # Crop vertically
        new_width = original_width
        new_height = int(original_width / target_aspect)
        left = 0
        top = (original_height - new_height) / 2
        right = original_width
        bottom = top + new_height

    # Crop and resize
    cropped_image = image.crop((left, top, right, bottom))
    resized_image = cropped_image.resize((target_width, target_height), Image.NEAREST)
    return resized_image


## Ordinary function
def resize(image: Image.Image,
           target_width: int,
           target_height: int) -> Image.Image:
    """
    Crops and resizes an image while preserving the aspect ratio.

    Args:
        image (Image.Image): Input PIL image to be cropped and resized.
        target_width (int): Target width of the output image.
        target_height (int): Target height of the output image.

    Returns:
        Image.Image: Cropped and resized image.
    """
    # Original dimensions
    resized_image = image.resize((target_width, target_height), Image.NEAREST)
    return resized_image


def move_mask_func(mask, direction, units):
    binary_mask = mask.squeeze()>0
    rows, cols = binary_mask.shape
    moved_mask = np.zeros_like(binary_mask, dtype=bool)

    if direction == 'down':
        # move down
        moved_mask[max(0, units):, :] = binary_mask[:rows - units, :]

    elif direction == 'up':
        # move up
        moved_mask[:rows - units, :] = binary_mask[units:, :]

    elif direction == 'right':
        # move left
        moved_mask[:, max(0, units):] = binary_mask[:, :cols - units]

    elif direction == 'left':
        # move right
        moved_mask[:, :cols - units] = binary_mask[:, units:]

    return moved_mask


def random_mask_func(mask, dilation_type='square', dilation_size=20):
    # Randomly select the size of dilation
    binary_mask = mask.squeeze()>0

    if dilation_type == 'square_dilation':
        structure = np.ones((dilation_size, dilation_size), dtype=bool)
        dilated_mask = binary_dilation(binary_mask, structure=structure)
    elif dilation_type == 'square_erosion':
        structure = np.ones((dilation_size, dilation_size), dtype=bool)
        dilated_mask = binary_erosion(binary_mask, structure=structure)
    elif dilation_type == 'bounding_box':
        # find the most left top and left bottom point
        rows, cols = np.where(binary_mask)
        if len(rows) == 0 or len(cols) == 0:
            return mask  # return original mask if no valid points

        min_row = np.min(rows)
        max_row = np.max(rows)
        min_col = np.min(cols)
        max_col = np.max(cols)

        # create a bounding box
        dilated_mask = np.zeros_like(binary_mask, dtype=bool)
        dilated_mask[min_row:max_row + 1, min_col:max_col + 1] = True

    elif dilation_type == 'bounding_ellipse':
        # find the most left top and left bottom point
        rows, cols = np.where(binary_mask)
        if len(rows) == 0 or len(cols) == 0:
            return mask  # return original mask if no valid points

        min_row = np.min(rows)
        max_row = np.max(rows)
        min_col = np.min(cols)
        max_col = np.max(cols)

        # calculate the center and axis length of the ellipse
        center = ((min_col + max_col) // 2, (min_row + max_row) // 2)
        a = (max_col - min_col) // 2  # half long axis
        b = (max_row - min_row) // 2  # half short axis

        # create a bounding ellipse
        y, x = np.ogrid[:mask.shape[0], :mask.shape[1]]
        ellipse_mask = ((x - center[0])**2 / a**2 + (y - center[1])**2 / b**2) <= 1
        dilated_mask = np.zeros_like(binary_mask, dtype=bool)
        dilated_mask[ellipse_mask] = True
    else:
        ValueError("dilation_type must be 'square' or 'ellipse'")

    # use binary dilation
    dilated_mask = np.uint8(dilated_mask[:,:,np.newaxis]) * 255
    return dilated_mask


## Gradio component function
def update_vlm_model(vlm_name):
    global vlm_model, vlm_processor
    if vlm_model is not None:
        del vlm_model
        torch.cuda.empty_cache()

    vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[vlm_name]

    ## we recommend using preload models, otherwise it will take a long time to download the model. you can edit the code via vlm_template.py
    if vlm_type == "llava-next":
        if vlm_processor != "" and vlm_model != "":
            vlm_model.to(device)
            return vlm_model_dropdown
        else:
            if os.path.exists(vlm_local_path):
                vlm_processor = LlavaNextProcessor.from_pretrained(vlm_local_path)
                vlm_model = LlavaNextForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto")
            else:
                if vlm_name == "llava-v1.6-mistral-7b-hf (Preload)":
                    vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
                    vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", torch_dtype="auto", device_map="auto")
                elif vlm_name == "llama3-llava-next-8b-hf (Preload)":
                    vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llama3-llava-next-8b-hf")
                    vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llama3-llava-next-8b-hf", torch_dtype="auto", device_map="auto")
                elif vlm_name == "llava-v1.6-vicuna-13b-hf (Preload)":
                    vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf")
                    vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf", torch_dtype="auto", device_map="auto")
                elif vlm_name == "llava-v1.6-34b-hf (Preload)":
                    vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-34b-hf")
                    vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-34b-hf", torch_dtype="auto", device_map="auto")
                elif vlm_name == "llava-next-72b-hf (Preload)":
                    vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-next-72b-hf")
                    vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-next-72b-hf", torch_dtype="auto", device_map="auto")
    elif vlm_type == "qwen2-vl":
        if vlm_processor != "" and vlm_model != "":
            vlm_model.to(device)
            return vlm_model_dropdown
        else:
            if os.path.exists(vlm_local_path):
                vlm_processor = Qwen2VLProcessor.from_pretrained(vlm_local_path)
                vlm_model = Qwen2VLForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto")
            else:
                if vlm_name == "qwen2-vl-2b-instruct (Preload)":
                    vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
                    vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", torch_dtype="auto", device_map="auto")
                elif vlm_name == "qwen2-vl-7b-instruct (Preload)":
                    vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
                    vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto")
                elif vlm_name == "qwen2-vl-72b-instruct (Preload)":
                    vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-72B-Instruct")
                    vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-72B-Instruct", torch_dtype="auto", device_map="auto")
    elif vlm_type == "openai":
        pass
    return "success"


def update_base_model(base_model_name):
    global pipe
    ## we recommend using preload models, otherwise it will take a long time to download the model. you can edit the code via base_model_template.py
    if pipe is not None:
        del pipe
        torch.cuda.empty_cache()
    base_model_path, pipe = base_models_template[base_model_name]
    if pipe != "":
        pipe.to(device)
    else:
        if os.path.exists(base_model_path):
            pipe = StableDiffusionBrushNetPipeline.from_pretrained(
                base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
            )
            # pipe.enable_xformers_memory_efficient_attention()
            pipe.enable_model_cpu_offload()
        else:
            raise gr.Error(f"The base model {base_model_name} does not exist")
    return "success"


def submit_GPT4o_KEY(GPT4o_KEY):
    global vlm_model, vlm_processor
    if vlm_model is not None:
        del vlm_model
        torch.cuda.empty_cache()
    try:
        vlm_model = OpenAI(api_key=GPT4o_KEY, base_url="https://api.deepseek.com")
        vlm_processor = ""
        response = vlm_model.chat.completions.create(
            model="deepseek-chat",
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Hello."}
            ]
        )
        response_str = response.choices[0].message.content

        return "Success, " + response_str, "GPT4-o (Highly Recommended)"
    except Exception as e:
        return "Invalid GPT4o API Key", "GPT4-o (Highly Recommended)"



def process(input_image,
            original_image,
            original_mask,
            prompt,
            negative_prompt,
            control_strength,
            seed,
            randomize_seed,
            guidance_scale,
            num_inference_steps,
            num_samples,
            blending,
            category,
            target_prompt,
            resize_default,
            aspect_ratio_name,
            invert_mask_state):
    if original_image is None:
        if input_image is None:
            raise gr.Error('Please upload the input image')
        else:
            image_pil = input_image["background"].convert("RGB")
            original_image = np.array(image_pil)
    if prompt is None or prompt == "":
        if target_prompt is None or target_prompt == "":
            raise gr.Error("Please input your instructions, e.g., remove the xxx")

    alpha_mask = input_image["layers"][0].split()[3]
    input_mask = np.asarray(alpha_mask)
    output_w, output_h = aspect_ratios[aspect_ratio_name]
    if output_w == "" or output_h == "":
        output_h, output_w = original_image.shape[:2]

        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
            original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
            original_image = np.array(original_image)
            if input_mask is not None:
                input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
                input_mask = np.array(input_mask)
            if original_mask is not None:
                original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
                original_mask = np.array(original_mask)
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        else:
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
            pass
    else:
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
        gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
        original_image = np.array(original_image)
        if input_mask is not None:
            input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
            input_mask = np.array(input_mask)
        if original_mask is not None:
            original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
            original_mask = np.array(original_mask)

    if invert_mask_state:
        original_mask = original_mask
    else:
        if input_mask.max() == 0:
            original_mask = original_mask
        else:
            original_mask = input_mask


    ## inpainting directly if target_prompt is not None
    if category is not None:
        pass
    elif target_prompt is not None and len(target_prompt) >= 1 and original_mask is not None:
        pass
    else:
        try:
            category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device)
        except Exception as e:
            raise gr.Error("Please select the correct VLM model and input the correct API Key first!")


    if original_mask is not None:
        original_mask = np.clip(original_mask, 0, 255).astype(np.uint8)
    else:
        try:
            object_wait_for_edit = vlm_response_object_wait_for_edit(
                vlm_processor,
                vlm_model,
                original_image,
                category,
                prompt,
                device)

            original_mask = vlm_response_mask(vlm_processor,
                                              vlm_model,
                                              category,
                                              original_image,
                                              prompt,
                                              object_wait_for_edit,
                                              sam,
                                              sam_predictor,
                                              sam_automask_generator,
                                              groundingdino_model,
                                              device).astype(np.uint8)
        except Exception as e:
            raise gr.Error("Please select the correct VLM model and input the correct API Key first!")

    if original_mask.ndim == 2:
        original_mask = original_mask[:,:,None]


    if target_prompt is not None and len(target_prompt) >= 1:
        prompt_after_apply_instruction = target_prompt

    else:
        try:
            prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(
                vlm_processor,
                vlm_model,
                original_image,
                prompt,
                device)
        except Exception as e:
            raise gr.Error("Please select the correct VLM model and input the correct API Key first!")

    generator = torch.Generator(device).manual_seed(random.randint(0, 2147483647) if randomize_seed else seed)


    with torch.autocast(device):
        image, mask_image, mask_np, init_image_np = BrushEdit_Pipeline(pipe,
                                                                       prompt_after_apply_instruction,
                                                                       original_mask,
                                                                       original_image,
                                                                       generator,
                                                                       num_inference_steps,
                                                                       guidance_scale,
                                                                       control_strength,
                                                                       negative_prompt,
                                                                       num_samples,
                                                                       blending)
    original_image = np.array(init_image_np)
    masked_image = original_image * (1 - (mask_np>0))
    masked_image = masked_image.astype(np.uint8)
    masked_image = Image.fromarray(masked_image)
    # Save the images (optional)
    # import uuid
    # uuid = str(uuid.uuid4())
    # image[0].save(f"outputs/image_edit_{uuid}_0.png")
    # image[1].save(f"outputs/image_edit_{uuid}_1.png")
    # image[2].save(f"outputs/image_edit_{uuid}_2.png")
    # image[3].save(f"outputs/image_edit_{uuid}_3.png")
    # mask_image.save(f"outputs/mask_{uuid}.png")
    # masked_image.save(f"outputs/masked_image_{uuid}.png")
    gr.Info(f"Target Prompt: {prompt_after_apply_instruction}", duration=20)
    return image, [mask_image], [masked_image], prompt, '', False


def generate_target_prompt(input_image,
                           original_image,
                           prompt):
    # load example image
    if isinstance(original_image, str):
        original_image = input_image

    prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(
        vlm_processor,
        vlm_model,
        original_image,
        prompt,
        device)
    return prompt_after_apply_instruction


def process_mask(input_image,
                 original_image,
                 prompt,
                 resize_default,
                 aspect_ratio_name):
    if original_image is None:
        raise gr.Error('Please upload the input image')
    if prompt is None:
        raise gr.Error("Please input your instructions, e.g., remove the xxx")

    ## load mask
    alpha_mask = input_image["layers"][0].split()[3]
    input_mask = np.array(alpha_mask)

    # load example image
    if isinstance(original_image, str):
        original_image = input_image["background"]

    if input_mask.max() == 0:
        category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device)

        object_wait_for_edit = vlm_response_object_wait_for_edit(vlm_processor,
                                                                 vlm_model,
                                                                 original_image,
                                                                 category,
                                                                 prompt,
                                                                 device)
        # original mask: h,w,1 [0, 255]
        original_mask = vlm_response_mask(
            vlm_processor,
            vlm_model,
            category,
            original_image,
            prompt,
            object_wait_for_edit,
            sam,
            sam_predictor,
            sam_automask_generator,
            groundingdino_model,
            device).astype(np.uint8)
    else:
        original_mask = input_mask.astype(np.uint8)
        category = None

    ## resize mask if needed
    output_w, output_h = aspect_ratios[aspect_ratio_name]
    if output_w == "" or output_h == "":
        output_h, output_w = original_image.shape[:2]
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
            original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
            original_image = np.array(original_image)
            if input_mask is not None:
                input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
                input_mask = np.array(input_mask)
            if original_mask is not None:
                original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
                original_mask = np.array(original_mask)
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        else:
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
            pass
    else:
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
        gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
        original_image = np.array(original_image)
        if input_mask is not None:
            input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
            input_mask = np.array(input_mask)
        if original_mask is not None:
            original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
            original_mask = np.array(original_mask)


    if original_mask.ndim == 2:
        original_mask = original_mask[:,:,None]

    mask_image = Image.fromarray(original_mask.squeeze().astype(np.uint8)).convert("RGB")

    masked_image = original_image * (1 - (original_mask>0))
    masked_image = masked_image.astype(np.uint8)
    masked_image = Image.fromarray(masked_image)

    return [masked_image], [mask_image], original_mask.astype(np.uint8), category


def process_random_mask(input_image,
                        original_image,
                        original_mask,
                        resize_default,
                        aspect_ratio_name,
                        ):

    alpha_mask = input_image["layers"][0].split()[3]
    input_mask = np.asarray(alpha_mask)

    output_w, output_h = aspect_ratios[aspect_ratio_name]
    if output_w == "" or output_h == "":
        output_h, output_w = original_image.shape[:2]
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
            original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
907 |
+
original_image = np.array(original_image)
|
908 |
+
if input_mask is not None:
|
909 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
910 |
+
input_mask = np.array(input_mask)
|
911 |
+
if original_mask is not None:
|
912 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
913 |
+
original_mask = np.array(original_mask)
|
914 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
915 |
+
else:
|
916 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
917 |
+
pass
|
918 |
+
else:
|
919 |
+
if resize_default:
|
920 |
+
short_side = min(output_w, output_h)
|
921 |
+
scale_ratio = 640 / short_side
|
922 |
+
output_w = int(output_w * scale_ratio)
|
923 |
+
output_h = int(output_h * scale_ratio)
|
924 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
925 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
926 |
+
original_image = np.array(original_image)
|
927 |
+
if input_mask is not None:
|
928 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
929 |
+
input_mask = np.array(input_mask)
|
930 |
+
if original_mask is not None:
|
931 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
932 |
+
original_mask = np.array(original_mask)
|
933 |
+
|
934 |
+
|
935 |
+
if input_mask.max() == 0:
|
936 |
+
original_mask = original_mask
|
937 |
+
else:
|
938 |
+
original_mask = input_mask
|
939 |
+
|
940 |
+
if original_mask is None:
|
941 |
+
raise gr.Error('Please generate mask first')
|
942 |
+
|
943 |
+
if original_mask.ndim == 2:
|
944 |
+
original_mask = original_mask[:,:,None]
|
945 |
+
|
946 |
+
dilation_type = np.random.choice(['bounding_box', 'bounding_ellipse'])
|
947 |
+
random_mask = random_mask_func(original_mask, dilation_type).squeeze()
|
948 |
+
|
949 |
+
mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
|
950 |
+
|
951 |
+
masked_image = original_image * (1 - (random_mask[:,:,None]>0))
|
952 |
+
masked_image = masked_image.astype(original_image.dtype)
|
953 |
+
masked_image = Image.fromarray(masked_image)
|
954 |
+
|
955 |
+
|
956 |
+
return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
|
957 |
+
|
958 |
+
|
959 |
+
def process_dilation_mask(input_image,
|
960 |
+
original_image,
|
961 |
+
original_mask,
|
962 |
+
resize_default,
|
963 |
+
aspect_ratio_name,
|
964 |
+
dilation_size=20):
|
965 |
+
|
966 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
967 |
+
input_mask = np.asarray(alpha_mask)
|
968 |
+
|
969 |
+
output_w, output_h = aspect_ratios[aspect_ratio_name]
|
970 |
+
if output_w == "" or output_h == "":
|
971 |
+
output_h, output_w = original_image.shape[:2]
|
972 |
+
if resize_default:
|
973 |
+
short_side = min(output_w, output_h)
|
974 |
+
scale_ratio = 640 / short_side
|
975 |
+
output_w = int(output_w * scale_ratio)
|
976 |
+
output_h = int(output_h * scale_ratio)
|
977 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
978 |
+
original_image = np.array(original_image)
|
979 |
+
if input_mask is not None:
|
980 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
981 |
+
input_mask = np.array(input_mask)
|
982 |
+
if original_mask is not None:
|
983 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
984 |
+
original_mask = np.array(original_mask)
|
985 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
986 |
+
else:
|
987 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
988 |
+
pass
|
989 |
+
else:
|
990 |
+
if resize_default:
|
991 |
+
short_side = min(output_w, output_h)
|
992 |
+
scale_ratio = 640 / short_side
|
993 |
+
output_w = int(output_w * scale_ratio)
|
994 |
+
output_h = int(output_h * scale_ratio)
|
995 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
996 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
997 |
+
original_image = np.array(original_image)
|
998 |
+
if input_mask is not None:
|
999 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
1000 |
+
input_mask = np.array(input_mask)
|
1001 |
+
if original_mask is not None:
|
1002 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
1003 |
+
original_mask = np.array(original_mask)
|
1004 |
+
|
1005 |
+
if input_mask.max() == 0:
|
1006 |
+
original_mask = original_mask
|
1007 |
+
else:
|
1008 |
+
original_mask = input_mask
|
1009 |
+
|
1010 |
+
if original_mask is None:
|
1011 |
+
raise gr.Error('Please generate mask first')
|
1012 |
+
|
1013 |
+
if original_mask.ndim == 2:
|
1014 |
+
original_mask = original_mask[:,:,None]
|
1015 |
+
|
1016 |
+
dilation_type = np.random.choice(['square_dilation'])
|
1017 |
+
random_mask = random_mask_func(original_mask, dilation_type, dilation_size).squeeze()
|
1018 |
+
|
1019 |
+
mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
|
1020 |
+
|
1021 |
+
masked_image = original_image * (1 - (random_mask[:,:,None]>0))
|
1022 |
+
masked_image = masked_image.astype(original_image.dtype)
|
1023 |
+
masked_image = Image.fromarray(masked_image)
|
1024 |
+
|
1025 |
+
return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
|
1026 |
+
|
1027 |
+
|
1028 |
+
def process_erosion_mask(input_image,
|
1029 |
+
original_image,
|
1030 |
+
original_mask,
|
1031 |
+
resize_default,
|
1032 |
+
aspect_ratio_name,
|
1033 |
+
dilation_size=20):
|
1034 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
1035 |
+
input_mask = np.asarray(alpha_mask)
|
1036 |
+
output_w, output_h = aspect_ratios[aspect_ratio_name]
|
1037 |
+
if output_w == "" or output_h == "":
|
1038 |
+
output_h, output_w = original_image.shape[:2]
|
1039 |
+
if resize_default:
|
1040 |
+
short_side = min(output_w, output_h)
|
1041 |
+
scale_ratio = 640 / short_side
|
1042 |
+
output_w = int(output_w * scale_ratio)
|
1043 |
+
output_h = int(output_h * scale_ratio)
|
1044 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
1045 |
+
original_image = np.array(original_image)
|
1046 |
+
if input_mask is not None:
|
1047 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
1048 |
+
input_mask = np.array(input_mask)
|
1049 |
+
if original_mask is not None:
|
1050 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
1051 |
+
original_mask = np.array(original_mask)
|
1052 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1053 |
+
else:
|
1054 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1055 |
+
pass
|
1056 |
+
else:
|
1057 |
+
if resize_default:
|
1058 |
+
short_side = min(output_w, output_h)
|
1059 |
+
scale_ratio = 640 / short_side
|
1060 |
+
output_w = int(output_w * scale_ratio)
|
1061 |
+
output_h = int(output_h * scale_ratio)
|
1062 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1063 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
1064 |
+
original_image = np.array(original_image)
|
1065 |
+
if input_mask is not None:
|
1066 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
1067 |
+
input_mask = np.array(input_mask)
|
1068 |
+
if original_mask is not None:
|
1069 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
1070 |
+
original_mask = np.array(original_mask)
|
1071 |
+
|
1072 |
+
if input_mask.max() == 0:
|
1073 |
+
original_mask = original_mask
|
1074 |
+
else:
|
1075 |
+
original_mask = input_mask
|
1076 |
+
|
1077 |
+
if original_mask is None:
|
1078 |
+
raise gr.Error('Please generate mask first')
|
1079 |
+
|
1080 |
+
if original_mask.ndim == 2:
|
1081 |
+
original_mask = original_mask[:,:,None]
|
1082 |
+
|
1083 |
+
dilation_type = np.random.choice(['square_erosion'])
|
1084 |
+
random_mask = random_mask_func(original_mask, dilation_type, dilation_size).squeeze()
|
1085 |
+
|
1086 |
+
mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
|
1087 |
+
|
1088 |
+
masked_image = original_image * (1 - (random_mask[:,:,None]>0))
|
1089 |
+
masked_image = masked_image.astype(original_image.dtype)
|
1090 |
+
masked_image = Image.fromarray(masked_image)
|
1091 |
+
|
1092 |
+
|
1093 |
+
return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
|
1094 |
+
|
1095 |
+
|
1096 |
+
def move_mask_left(input_image,
|
1097 |
+
original_image,
|
1098 |
+
original_mask,
|
1099 |
+
moving_pixels,
|
1100 |
+
resize_default,
|
1101 |
+
aspect_ratio_name):
|
1102 |
+
|
1103 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
1104 |
+
input_mask = np.asarray(alpha_mask)
|
1105 |
+
|
1106 |
+
output_w, output_h = aspect_ratios[aspect_ratio_name]
|
1107 |
+
if output_w == "" or output_h == "":
|
1108 |
+
output_h, output_w = original_image.shape[:2]
|
1109 |
+
if resize_default:
|
1110 |
+
short_side = min(output_w, output_h)
|
1111 |
+
scale_ratio = 640 / short_side
|
1112 |
+
output_w = int(output_w * scale_ratio)
|
1113 |
+
output_h = int(output_h * scale_ratio)
|
1114 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
1115 |
+
original_image = np.array(original_image)
|
1116 |
+
if input_mask is not None:
|
1117 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
1118 |
+
input_mask = np.array(input_mask)
|
1119 |
+
if original_mask is not None:
|
1120 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
1121 |
+
original_mask = np.array(original_mask)
|
1122 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1123 |
+
else:
|
1124 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1125 |
+
pass
|
1126 |
+
else:
|
1127 |
+
if resize_default:
|
1128 |
+
short_side = min(output_w, output_h)
|
1129 |
+
scale_ratio = 640 / short_side
|
1130 |
+
output_w = int(output_w * scale_ratio)
|
1131 |
+
output_h = int(output_h * scale_ratio)
|
1132 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1133 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
1134 |
+
original_image = np.array(original_image)
|
1135 |
+
if input_mask is not None:
|
1136 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
1137 |
+
input_mask = np.array(input_mask)
|
1138 |
+
if original_mask is not None:
|
1139 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
1140 |
+
original_mask = np.array(original_mask)
|
1141 |
+
|
1142 |
+
if input_mask.max() == 0:
|
1143 |
+
original_mask = original_mask
|
1144 |
+
else:
|
1145 |
+
original_mask = input_mask
|
1146 |
+
|
1147 |
+
if original_mask is None:
|
1148 |
+
raise gr.Error('Please generate mask first')
|
1149 |
+
|
1150 |
+
if original_mask.ndim == 2:
|
1151 |
+
original_mask = original_mask[:,:,None]
|
1152 |
+
|
1153 |
+
moved_mask = move_mask_func(original_mask, 'left', int(moving_pixels)).squeeze()
|
1154 |
+
mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
|
1155 |
+
|
1156 |
+
masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
|
1157 |
+
masked_image = masked_image.astype(original_image.dtype)
|
1158 |
+
masked_image = Image.fromarray(masked_image)
|
1159 |
+
|
1160 |
+
if moved_mask.max() <= 1:
|
1161 |
+
moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
|
1162 |
+
original_mask = moved_mask
|
1163 |
+
return [masked_image], [mask_image], original_mask.astype(np.uint8)
|
1164 |
+
|
1165 |
+
|
1166 |
+
def move_mask_right(input_image,
|
1167 |
+
original_image,
|
1168 |
+
original_mask,
|
1169 |
+
moving_pixels,
|
1170 |
+
resize_default,
|
1171 |
+
aspect_ratio_name):
|
1172 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
1173 |
+
input_mask = np.asarray(alpha_mask)
|
1174 |
+
|
1175 |
+
output_w, output_h = aspect_ratios[aspect_ratio_name]
|
1176 |
+
if output_w == "" or output_h == "":
|
1177 |
+
output_h, output_w = original_image.shape[:2]
|
1178 |
+
if resize_default:
|
1179 |
+
short_side = min(output_w, output_h)
|
1180 |
+
scale_ratio = 640 / short_side
|
1181 |
+
output_w = int(output_w * scale_ratio)
|
1182 |
+
output_h = int(output_h * scale_ratio)
|
1183 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
1184 |
+
original_image = np.array(original_image)
|
1185 |
+
if input_mask is not None:
|
1186 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
1187 |
+
input_mask = np.array(input_mask)
|
1188 |
+
if original_mask is not None:
|
1189 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
1190 |
+
original_mask = np.array(original_mask)
|
1191 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1192 |
+
else:
|
1193 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1194 |
+
pass
|
1195 |
+
else:
|
1196 |
+
if resize_default:
|
1197 |
+
short_side = min(output_w, output_h)
|
1198 |
+
scale_ratio = 640 / short_side
|
1199 |
+
output_w = int(output_w * scale_ratio)
|
1200 |
+
output_h = int(output_h * scale_ratio)
|
1201 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1202 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
1203 |
+
original_image = np.array(original_image)
|
1204 |
+
if input_mask is not None:
|
1205 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
1206 |
+
input_mask = np.array(input_mask)
|
1207 |
+
if original_mask is not None:
|
1208 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
1209 |
+
original_mask = np.array(original_mask)
|
1210 |
+
|
1211 |
+
if input_mask.max() == 0:
|
1212 |
+
original_mask = original_mask
|
1213 |
+
else:
|
1214 |
+
original_mask = input_mask
|
1215 |
+
|
1216 |
+
if original_mask is None:
|
1217 |
+
raise gr.Error('Please generate mask first')
|
1218 |
+
|
1219 |
+
if original_mask.ndim == 2:
|
1220 |
+
original_mask = original_mask[:,:,None]
|
1221 |
+
|
1222 |
+
moved_mask = move_mask_func(original_mask, 'right', int(moving_pixels)).squeeze()
|
1223 |
+
|
1224 |
+
mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
|
1225 |
+
|
1226 |
+
masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
|
1227 |
+
masked_image = masked_image.astype(original_image.dtype)
|
1228 |
+
masked_image = Image.fromarray(masked_image)
|
1229 |
+
|
1230 |
+
|
1231 |
+
if moved_mask.max() <= 1:
|
1232 |
+
moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
|
1233 |
+
original_mask = moved_mask
|
1234 |
+
|
1235 |
+
return [masked_image], [mask_image], original_mask.astype(np.uint8)
|
1236 |
+
|
1237 |
+
|
1238 |
+
def move_mask_up(input_image,
|
1239 |
+
original_image,
|
1240 |
+
original_mask,
|
1241 |
+
moving_pixels,
|
1242 |
+
resize_default,
|
1243 |
+
aspect_ratio_name):
|
1244 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
1245 |
+
input_mask = np.asarray(alpha_mask)
|
1246 |
+
|
1247 |
+
output_w, output_h = aspect_ratios[aspect_ratio_name]
|
1248 |
+
if output_w == "" or output_h == "":
|
1249 |
+
output_h, output_w = original_image.shape[:2]
|
1250 |
+
if resize_default:
|
1251 |
+
short_side = min(output_w, output_h)
|
1252 |
+
scale_ratio = 640 / short_side
|
1253 |
+
output_w = int(output_w * scale_ratio)
|
1254 |
+
output_h = int(output_h * scale_ratio)
|
1255 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
1256 |
+
original_image = np.array(original_image)
|
1257 |
+
if input_mask is not None:
|
1258 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
1259 |
+
input_mask = np.array(input_mask)
|
1260 |
+
if original_mask is not None:
|
1261 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
1262 |
+
original_mask = np.array(original_mask)
|
1263 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1264 |
+
else:
|
1265 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1266 |
+
pass
|
1267 |
+
else:
|
1268 |
+
if resize_default:
|
1269 |
+
short_side = min(output_w, output_h)
|
1270 |
+
scale_ratio = 640 / short_side
|
1271 |
+
output_w = int(output_w * scale_ratio)
|
1272 |
+
output_h = int(output_h * scale_ratio)
|
1273 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1274 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
1275 |
+
original_image = np.array(original_image)
|
1276 |
+
if input_mask is not None:
|
1277 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
1278 |
+
input_mask = np.array(input_mask)
|
1279 |
+
if original_mask is not None:
|
1280 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
1281 |
+
original_mask = np.array(original_mask)
|
1282 |
+
|
1283 |
+
if input_mask.max() == 0:
|
1284 |
+
original_mask = original_mask
|
1285 |
+
else:
|
1286 |
+
original_mask = input_mask
|
1287 |
+
|
1288 |
+
if original_mask is None:
|
1289 |
+
raise gr.Error('Please generate mask first')
|
1290 |
+
|
1291 |
+
if original_mask.ndim == 2:
|
1292 |
+
original_mask = original_mask[:,:,None]
|
1293 |
+
|
1294 |
+
moved_mask = move_mask_func(original_mask, 'up', int(moving_pixels)).squeeze()
|
1295 |
+
mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
|
1296 |
+
|
1297 |
+
masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
|
1298 |
+
masked_image = masked_image.astype(original_image.dtype)
|
1299 |
+
masked_image = Image.fromarray(masked_image)
|
1300 |
+
|
1301 |
+
if moved_mask.max() <= 1:
|
1302 |
+
moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
|
1303 |
+
original_mask = moved_mask
|
1304 |
+
|
1305 |
+
return [masked_image], [mask_image], original_mask.astype(np.uint8)
|
1306 |
+
|
1307 |
+
|
1308 |
+
def move_mask_down(input_image,
|
1309 |
+
original_image,
|
1310 |
+
original_mask,
|
1311 |
+
moving_pixels,
|
1312 |
+
resize_default,
|
1313 |
+
aspect_ratio_name):
|
1314 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
1315 |
+
input_mask = np.asarray(alpha_mask)
|
1316 |
+
output_w, output_h = aspect_ratios[aspect_ratio_name]
|
1317 |
+
if output_w == "" or output_h == "":
|
1318 |
+
output_h, output_w = original_image.shape[:2]
|
1319 |
+
if resize_default:
|
1320 |
+
short_side = min(output_w, output_h)
|
1321 |
+
scale_ratio = 640 / short_side
|
1322 |
+
output_w = int(output_w * scale_ratio)
|
1323 |
+
output_h = int(output_h * scale_ratio)
|
1324 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
1325 |
+
original_image = np.array(original_image)
|
1326 |
+
if input_mask is not None:
|
1327 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
1328 |
+
input_mask = np.array(input_mask)
|
1329 |
+
if original_mask is not None:
|
1330 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
1331 |
+
original_mask = np.array(original_mask)
|
1332 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1333 |
+
else:
|
1334 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1335 |
+
pass
|
1336 |
+
else:
|
1337 |
+
if resize_default:
|
1338 |
+
short_side = min(output_w, output_h)
|
1339 |
+
scale_ratio = 640 / short_side
|
1340 |
+
output_w = int(output_w * scale_ratio)
|
1341 |
+
output_h = int(output_h * scale_ratio)
|
1342 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1343 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
1344 |
+
original_image = np.array(original_image)
|
1345 |
+
if input_mask is not None:
|
1346 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
1347 |
+
input_mask = np.array(input_mask)
|
1348 |
+
if original_mask is not None:
|
1349 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
1350 |
+
original_mask = np.array(original_mask)
|
1351 |
+
|
1352 |
+
if input_mask.max() == 0:
|
1353 |
+
original_mask = original_mask
|
1354 |
+
else:
|
1355 |
+
original_mask = input_mask
|
1356 |
+
|
1357 |
+
if original_mask is None:
|
1358 |
+
raise gr.Error('Please generate mask first')
|
1359 |
+
|
1360 |
+
if original_mask.ndim == 2:
|
1361 |
+
original_mask = original_mask[:,:,None]
|
1362 |
+
|
1363 |
+
moved_mask = move_mask_func(original_mask, 'down', int(moving_pixels)).squeeze()
|
1364 |
+
mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
|
1365 |
+
|
1366 |
+
masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
|
1367 |
+
masked_image = masked_image.astype(original_image.dtype)
|
1368 |
+
masked_image = Image.fromarray(masked_image)
|
1369 |
+
|
1370 |
+
if moved_mask.max() <= 1:
|
1371 |
+
moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
|
1372 |
+
original_mask = moved_mask
|
1373 |
+
|
1374 |
+
return [masked_image], [mask_image], original_mask.astype(np.uint8)
|
1375 |
+
|
1376 |
+
|
1377 |
+
def invert_mask(input_image,
|
1378 |
+
original_image,
|
1379 |
+
original_mask,
|
1380 |
+
):
|
1381 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
1382 |
+
input_mask = np.asarray(alpha_mask)
|
1383 |
+
if input_mask.max() == 0:
|
1384 |
+
original_mask = 1 - (original_mask>0).astype(np.uint8)
|
1385 |
+
else:
|
1386 |
+
original_mask = 1 - (input_mask>0).astype(np.uint8)
|
1387 |
+
|
1388 |
+
if original_mask is None:
|
1389 |
+
raise gr.Error('Please generate mask first')
|
1390 |
+
|
1391 |
+
original_mask = original_mask.squeeze()
|
1392 |
+
mask_image = Image.fromarray(original_mask*255).convert("RGB")
|
1393 |
+
|
1394 |
+
if original_mask.ndim == 2:
|
1395 |
+
original_mask = original_mask[:,:,None]
|
1396 |
+
|
1397 |
+
if original_mask.max() <= 1:
|
1398 |
+
original_mask = (original_mask * 255).astype(np.uint8)
|
1399 |
+
|
1400 |
+
masked_image = original_image * (1 - (original_mask>0))
|
1401 |
+
masked_image = masked_image.astype(original_image.dtype)
|
1402 |
+
masked_image = Image.fromarray(masked_image)
|
1403 |
+
|
1404 |
+
return [masked_image], [mask_image], original_mask, True
|
1405 |
+
|
1406 |
+
|
1407 |
+
def init_img(base,
|
1408 |
+
init_type,
|
1409 |
+
prompt,
|
1410 |
+
aspect_ratio,
|
1411 |
+
example_change_times
|
1412 |
+
):
|
1413 |
+
image_pil = base["background"].convert("RGB")
|
1414 |
+
original_image = np.array(image_pil)
|
1415 |
+
if max(original_image.shape[0], original_image.shape[1]) * 1.0 / min(original_image.shape[0], original_image.shape[1])>2.0:
|
1416 |
+
raise gr.Error('image aspect ratio cannot be larger than 2.0')
|
1417 |
+
if init_type in MASK_IMAGE_PATH.keys() and example_change_times < 2:
|
1418 |
+
mask_gallery = [Image.open(MASK_IMAGE_PATH[init_type]).convert("L")]
|
1419 |
+
masked_gallery = [Image.open(MASKED_IMAGE_PATH[init_type]).convert("RGB")]
|
1420 |
+
result_gallery = [Image.open(OUTPUT_IMAGE_PATH[init_type]).convert("RGB")]
|
1421 |
+
width, height = image_pil.size
|
1422 |
+
image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
|
1423 |
+
height_new, width_new = image_processor.get_default_height_width(image_pil, height, width)
|
1424 |
+
image_pil = image_pil.resize((width_new, height_new))
|
1425 |
+
mask_gallery[0] = mask_gallery[0].resize((width_new, height_new))
|
1426 |
+
masked_gallery[0] = masked_gallery[0].resize((width_new, height_new))
|
1427 |
+
result_gallery[0] = result_gallery[0].resize((width_new, height_new))
|
1428 |
+
original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1
|
1429 |
+
return base, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, "", "", "Custom resolution", False, False, example_change_times
|
1430 |
+
else:
|
1431 |
+
if aspect_ratio not in ASPECT_RATIO_LABELS:
|
1432 |
+
aspect_ratio = "Custom resolution"
|
1433 |
+
return base, original_image, None, "", None, None, None, "", "", aspect_ratio, True, False, 0
|
1434 |
+
|
1435 |
+
|
1436 |
+
def reset_func(input_image,
|
1437 |
+
original_image,
|
1438 |
+
original_mask,
|
1439 |
+
prompt,
|
1440 |
+
target_prompt,
|
1441 |
+
):
|
1442 |
+
input_image = None
|
1443 |
+
original_image = None
|
1444 |
+
original_mask = None
|
1445 |
+
prompt = ''
|
1446 |
+
mask_gallery = []
|
1447 |
+
masked_gallery = []
|
1448 |
+
result_gallery = []
|
1449 |
+
target_prompt = ''
|
1450 |
+
if torch.cuda.is_available():
|
1451 |
+
torch.cuda.empty_cache()
|
1452 |
+
return input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, True, False
|
1453 |
+
|
1454 |
+
|
1455 |
+
def update_example(example_type,
|
1456 |
+
prompt,
|
1457 |
+
example_change_times):
|
1458 |
+
input_image = INPUT_IMAGE_PATH[example_type]
|
1459 |
+
image_pil = Image.open(input_image).convert("RGB")
|
1460 |
+
mask_gallery = [Image.open(MASK_IMAGE_PATH[example_type]).convert("L")]
|
1461 |
+
masked_gallery = [Image.open(MASKED_IMAGE_PATH[example_type]).convert("RGB")]
|
1462 |
+
result_gallery = [Image.open(OUTPUT_IMAGE_PATH[example_type]).convert("RGB")]
|
1463 |
+
width, height = image_pil.size
|
1464 |
+
image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
|
1465 |
+
height_new, width_new = image_processor.get_default_height_width(image_pil, height, width)
|
1466 |
+
image_pil = image_pil.resize((width_new, height_new))
|
1467 |
+
mask_gallery[0] = mask_gallery[0].resize((width_new, height_new))
|
1468 |
+
masked_gallery[0] = masked_gallery[0].resize((width_new, height_new))
|
1469 |
+
result_gallery[0] = result_gallery[0].resize((width_new, height_new))
|
1470 |
+
|
1471 |
+
original_image = np.array(image_pil)
|
1472 |
+
original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1
|
1473 |
+
aspect_ratio = "Custom resolution"
|
1474 |
+
example_change_times += 1
|
1475 |
+
return input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, "", False, example_change_times
|
1476 |
+
|
1477 |
+
|
1478 |
+
block = gr.Blocks(
|
1479 |
+
theme=gr.themes.Soft(
|
1480 |
+
radius_size=gr.themes.sizes.radius_none,
|
1481 |
+
text_size=gr.themes.sizes.text_md
|
1482 |
+
)
|
1483 |
+
)
|
1484 |
+
with block as demo:
|
1485 |
+
with gr.Row():
|
1486 |
+
with gr.Column():
|
1487 |
+
gr.HTML(head)
|
1488 |
+
|
1489 |
+
gr.Markdown(descriptions)
|
1490 |
+
|
1491 |
+
with gr.Accordion(label="🧭 Instructions:", open=True, elem_id="accordion"):
|
1492 |
+
with gr.Row(equal_height=True):
|
1493 |
+
gr.Markdown(instructions)
|
1494 |
+
|
1495 |
+
original_image = gr.State(value=None)
|
1496 |
+
original_mask = gr.State(value=None)
|
1497 |
+
category = gr.State(value=None)
|
1498 |
+
status = gr.State(value=None)
|
1499 |
+
invert_mask_state = gr.State(value=False)
|
1500 |
+
example_change_times = gr.State(value=0)
|
1501 |
+
|
1502 |
+
|
1503 |
+
with gr.Row():
|
1504 |
+
with gr.Column():
|
1505 |
+
with gr.Row():
|
1506 |
+
input_image = gr.ImageEditor(
|
1507 |
+
label="Input Image",
|
1508 |
+
type="pil",
|
1509 |
+
brush=gr.Brush(colors=["#FFFFFF"], default_size = 30, color_mode="fixed"),
|
1510 |
+
layers = False,
|
1511 |
+
interactive=True,
|
1512 |
+
height=1024,
|
1513 |
+
sources=["upload"],
|
1514 |
+
placeholder="Please click here or the icon below to upload the image.",
|
1515 |
+
)
|
1516 |
+
|
1517 |
+
prompt = gr.Textbox(label="⌨️ Instruction", placeholder="Please input your instruction.", value="",lines=1)
|
1518 |
+
run_button = gr.Button("💫 Run")
|
1519 |
+
|
1520 |
+
vlm_model_dropdown = gr.Dropdown(label="VLM model", choices=VLM_MODEL_NAMES, value=DEFAULT_VLM_MODEL_NAME, interactive=True)
|
1521 |
+
with gr.Group():
|
1522 |
+
with gr.Row():
|
1523 |
+
# GPT4o_KEY = gr.Textbox(label="GPT4o API Key", placeholder="Please input your GPT4o API Key when use GPT4o VLM (highly recommended).", value="", lines=1)
|
1524 |
+
# GPT4o_KEY = gr.Textbox(type="password", value="sk-d145b963a92649a88843caeb741e8bbc")
|
1525 |
+
GPT4o_KEY = gr.Textbox(label="GPT4o API Key", value="sk-d145b963a92649a88843caeb741e8bbc", lines=1)
|
1526 |
+
|
1527 |
+
GPT4o_KEY_submit = gr.Button("Submit and Verify")
|
1528 |
+
|
1529 |
+
|
1530 |
+
aspect_ratio = gr.Dropdown(label="Output aspect ratio", choices=ASPECT_RATIO_LABELS, value=DEFAULT_ASPECT_RATIO)
|
1531 |
+
resize_default = gr.Checkbox(label="Short edge resize to 640px", value=True)
|
1532 |
+
|
1533 |
+
with gr.Row():
|
1534 |
+
mask_button = gr.Button("Generate Mask")
|
1535 |
+
random_mask_button = gr.Button("Square/Circle Mask")
|
1536 |
+
|
1537 |
+
|
1538 |
+
with gr.Row():
|
1539 |
+
generate_target_prompt_button = gr.Button("Generate Target Prompt")
|
1540 |
+
|
1541 |
+
target_prompt = gr.Text(
|
1542 |
+
label="Input Target Prompt",
|
1543 |
+
max_lines=5,
|
1544 |
+
placeholder="VLM-generated target prompt, you can first generate if and then modify it (optional)",
|
1545 |
+
value='',
|
1546 |
+
lines=2
|
1547 |
+
)
|
1548 |
+
|
1549 |
+
with gr.Accordion("Advanced Options", open=False, elem_id="accordion1"):
|
1550 |
+
base_model_dropdown = gr.Dropdown(label="Base model", choices=BASE_MODELS, value=DEFAULT_BASE_MODEL, interactive=True)
|
1551 |
+
negative_prompt = gr.Text(
|
1552 |
+
label="Negative Prompt",
|
1553 |
+
max_lines=5,
|
1554 |
+
placeholder="Please input your negative prompt",
|
1555 |
+
value='ugly, low quality',lines=1
|
1556 |
+
)
|
1557 |
+
|
1558 |
+
control_strength = gr.Slider(
|
1559 |
+
label="Control Strength: ", show_label=True, minimum=0, maximum=1.1, value=1, step=0.01
|
1560 |
+
)
|
1561 |
+
with gr.Group():
|
1562 |
+
seed = gr.Slider(
|
1563 |
+
label="Seed: ", minimum=0, maximum=2147483647, step=1, value=648464818
|
1564 |
+
)
|
1565 |
+
randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
|
1566 |
+
|
1567 |
+
blending = gr.Checkbox(label="Blending mode", value=True)
|
1568 |
+
|
1569 |
+
|
1570 |
+
num_samples = gr.Slider(
|
1571 |
+
label="Num samples", minimum=0, maximum=4, step=1, value=4
|
1572 |
+
)
|
1573 |
+
|
1574 |
+
with gr.Group():
|
1575 |
+
with gr.Row():
|
1576 |
+
guidance_scale = gr.Slider(
|
1577 |
+
label="Guidance scale",
|
1578 |
+
minimum=1,
|
1579 |
+
maximum=12,
|
1580 |
+
step=0.1,
|
1581 |
+
value=7.5,
|
1582 |
+
)
|
1583 |
+
num_inference_steps = gr.Slider(
|
1584 |
+
label="Number of inference steps",
|
1585 |
+
minimum=1,
|
1586 |
+
maximum=50,
|
1587 |
+
step=1,
|
1588 |
+
value=50,
|
1589 |
+
)
|
1590 |
+
|
1591 |
+
|
1592 |
+
with gr.Column():
|
1593 |
+
with gr.Row():
|
1594 |
+
with gr.Tab(elem_classes="feedback", label="Masked Image"):
|
1595 |
+
masked_gallery = gr.Gallery(label='Masked Image', show_label=True, elem_id="gallery", preview=True, height=360)
|
1596 |
+
with gr.Tab(elem_classes="feedback", label="Mask"):
|
1597 |
+
mask_gallery = gr.Gallery(label='Mask', show_label=True, elem_id="gallery", preview=True, height=360)
|
1598 |
+
|
1599 |
+
invert_mask_button = gr.Button("Invert Mask")
|
1600 |
+
dilation_size = gr.Slider(
|
1601 |
+
label="Dilation size: ", minimum=0, maximum=50, step=1, value=20
|
1602 |
+
)
|
1603 |
+
with gr.Row():
|
1604 |
+
dilation_mask_button = gr.Button("Dilation Generated Mask")
|
1605 |
+
erosion_mask_button = gr.Button("Erosion Generated Mask")
|
1606 |
+
|
1607 |
+
moving_pixels = gr.Slider(
|
1608 |
+
label="Moving pixels:", show_label=True, minimum=0, maximum=50, value=4, step=1
|
1609 |
+
)
|
1610 |
+
with gr.Row():
|
1611 |
+
move_left_button = gr.Button("Move Left")
|
1612 |
+
move_right_button = gr.Button("Move Right")
|
1613 |
+
with gr.Row():
|
1614 |
+
move_up_button = gr.Button("Move Up")
|
1615 |
+
move_down_button = gr.Button("Move Down")
|
1616 |
+
|
1617 |
+
with gr.Tab(elem_classes="feedback", label="Output"):
|
1618 |
+
result_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", preview=True, height=400)
|
1619 |
+
|
1620 |
+
# target_prompt_output = gr.Text(label="Output Target Prompt", value="", lines=1, interactive=False)
|
1621 |
+
|
1622 |
+
reset_button = gr.Button("Reset")
|
1623 |
+
|
1624 |
+
init_type = gr.Textbox(label="Init Name", value="", visible=False)
|
1625 |
+
example_type = gr.Textbox(label="Example Name", value="", visible=False)
|
1626 |
+
|
1627 |
+
|
1628 |
+
|
1629 |
+
with gr.Row():
|
1630 |
+
example = gr.Examples(
|
1631 |
+
label="Quick Example",
|
1632 |
+
examples=EXAMPLES,
|
1633 |
+
inputs=[input_image, prompt, seed, init_type, example_type, blending, resize_default, vlm_model_dropdown],
|
1634 |
+
examples_per_page=10,
|
1635 |
+
cache_examples=False,
|
1636 |
+
)
|
1637 |
+
|
1638 |
+
|
1639 |
+
with gr.Accordion(label="🎬 Feature Details:", open=True, elem_id="accordion"):
|
1640 |
+
with gr.Row(equal_height=True):
|
1641 |
+
gr.Markdown(tips)
|
1642 |
+
|
1643 |
+
with gr.Row():
|
1644 |
+
gr.Markdown(citation)
|
1645 |
+
|
1646 |
+
## gr.Examples cannot be used to update the gr.Gallery, so we need the following two functions to update it.
|
1647 |
+
## And we need to solve the conflict between the upload and change example functions.
|
1648 |
+
input_image.upload(
|
1649 |
+
init_img,
|
1650 |
+
[input_image, init_type, prompt, aspect_ratio, example_change_times],
|
1651 |
+
[input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, init_type, aspect_ratio, resize_default, invert_mask_state, example_change_times]
|
1652 |
+
)
|
1653 |
+
example_type.change(fn=update_example, inputs=[example_type, prompt, example_change_times], outputs=[input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, target_prompt, invert_mask_state, example_change_times])
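The two comment lines above describe a Gradio workaround: gr.Examples cannot write directly into a gr.Gallery, so the example click is routed through a hidden gr.Textbox whose .change() callback fills the gallery. A minimal, self-contained sketch of that pattern (the names fill_gallery and hidden_name are illustrative, not code from this repo):

```python
# Sketch of the gr.Examples -> hidden Textbox -> Gallery workaround (hypothetical names).
import numpy as np
import gradio as gr

def fill_gallery(name):
    # Stand-in for update_example: the real app loads the cached mask/result images
    # for the chosen example; here we just return a solid-colour placeholder image.
    return [np.full((64, 64, 3), 200, dtype=np.uint8)]

with gr.Blocks() as sketch_demo:
    hidden_name = gr.Textbox(visible=False)       # written by gr.Examples
    gallery = gr.Gallery(label="Result")          # cannot be an Examples target directly
    gr.Examples(examples=[["frog"], ["angel_christmas"]], inputs=[hidden_name])
    hidden_name.change(fn=fill_gallery, inputs=[hidden_name], outputs=[gallery])

# sketch_demo.launch()
```

Clicking an example populates the hidden textbox, which triggers the .change() callback and updates the gallery, mirroring the example_type / update_example wiring above.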
|
1654 |
+
|
1655 |
+
## vlm and base model dropdown
|
1656 |
+
vlm_model_dropdown.change(fn=update_vlm_model, inputs=[vlm_model_dropdown], outputs=[status])
|
1657 |
+
base_model_dropdown.change(fn=update_base_model, inputs=[base_model_dropdown], outputs=[status])
|
1658 |
+
|
1659 |
+
|
1660 |
+
GPT4o_KEY_submit.click(fn=submit_GPT4o_KEY, inputs=[GPT4o_KEY], outputs=[GPT4o_KEY, vlm_model_dropdown])
|
1661 |
+
invert_mask_button.click(fn=invert_mask, inputs=[input_image, original_image, original_mask], outputs=[masked_gallery, mask_gallery, original_mask, invert_mask_state])
|
1662 |
+
|
1663 |
+
|
1664 |
+
ips=[input_image,
|
1665 |
+
original_image,
|
1666 |
+
original_mask,
|
1667 |
+
prompt,
|
1668 |
+
negative_prompt,
|
1669 |
+
control_strength,
|
1670 |
+
seed,
|
1671 |
+
randomize_seed,
|
1672 |
+
guidance_scale,
|
1673 |
+
num_inference_steps,
|
1674 |
+
num_samples,
|
1675 |
+
blending,
|
1676 |
+
category,
|
1677 |
+
target_prompt,
|
1678 |
+
resize_default,
|
1679 |
+
aspect_ratio,
|
1680 |
+
invert_mask_state]
|
1681 |
+
|
1682 |
+
## run brushedit
|
1683 |
+
run_button.click(fn=process, inputs=ips, outputs=[result_gallery, mask_gallery, masked_gallery, prompt, target_prompt, invert_mask_state])
|
1684 |
+
|
1685 |
+
## mask func
|
1686 |
+
mask_button.click(fn=process_mask, inputs=[input_image, original_image, prompt, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask, category])
|
1687 |
+
random_mask_button.click(fn=process_random_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
|
1688 |
+
dilation_mask_button.click(fn=process_dilation_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[ masked_gallery, mask_gallery, original_mask])
|
1689 |
+
erosion_mask_button.click(fn=process_erosion_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[ masked_gallery, mask_gallery, original_mask])
|
1690 |
+
|
1691 |
+
## move mask func
|
1692 |
+
move_left_button.click(fn=move_mask_left, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
|
1693 |
+
move_right_button.click(fn=move_mask_right, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
|
1694 |
+
move_up_button.click(fn=move_mask_up, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
|
1695 |
+
move_down_button.click(fn=move_mask_down, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
|
1696 |
+
|
1697 |
+
## prompt func
|
1698 |
+
generate_target_prompt_button.click(fn=generate_target_prompt, inputs=[input_image, original_image, prompt], outputs=[target_prompt])
|
1699 |
+
|
1700 |
+
## reset func
|
1701 |
+
reset_button.click(fn=reset_func, inputs=[input_image, original_image, original_mask, prompt, target_prompt], outputs=[input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, resize_default, invert_mask_state])
|
1702 |
+
|
1703 |
+
# if you have a localhost access error, try launching with the explicit server settings below
|
1704 |
+
demo.launch(server_name="0.0.0.0", server_port=12345, share=True)
|
1705 |
+
# demo.launch()
|
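Every mask-processing function in the file above (process_mask, process_random_mask, the dilation/erosion functions, and the four move_mask_* functions) repeats the same "resize so the short side is 640 px" block for the image and both masks. A minimal sketch of how that duplication could be factored into one shared helper; the helper name and signature are hypothetical, not part of this repository:

```python
# Hypothetical refactor sketch, not code from this repo: shared short-side resize.
import numpy as np
from PIL import Image

def resize_short_side(array: np.ndarray, target: int = 640) -> np.ndarray:
    """Scale an H x W (x C) uint8 array so that min(H, W) == target, keeping aspect ratio."""
    h, w = array.shape[:2]
    scale = target / min(h, w)
    new_size = (int(w * scale), int(h * scale))           # PIL expects (width, height)
    return np.array(Image.fromarray(np.squeeze(array)).resize(new_size))
```

Each function could then call resize_short_side(original_image), resize_short_side(input_mask), and resize_short_side(original_mask) in place of the repeated inline short_side / scale_ratio blocks.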
brushedit_app_315_0.py
ADDED
@@ -0,0 +1,1696 @@
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import os, random, sys
|
4 |
+
import numpy as np
|
5 |
+
import requests
|
6 |
+
import torch
|
7 |
+
|
8 |
+
|
9 |
+
import gradio as gr
|
10 |
+
|
11 |
+
from PIL import Image
|
12 |
+
|
13 |
+
|
14 |
+
from huggingface_hub import hf_hub_download, snapshot_download
|
15 |
+
from scipy.ndimage import binary_dilation, binary_erosion
|
16 |
+
from transformers import (LlavaNextProcessor, LlavaNextForConditionalGeneration,
|
17 |
+
Qwen2VLForConditionalGeneration, Qwen2VLProcessor)
|
18 |
+
|
19 |
+
from segment_anything import SamPredictor, build_sam, SamAutomaticMaskGenerator
|
20 |
+
from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler
|
21 |
+
from diffusers.image_processor import VaeImageProcessor
|
22 |
+
|
23 |
+
|
24 |
+
from app.src.vlm_pipeline import (
|
25 |
+
vlm_response_editing_type,
|
26 |
+
vlm_response_object_wait_for_edit,
|
27 |
+
vlm_response_mask,
|
28 |
+
vlm_response_prompt_after_apply_instruction
|
29 |
+
)
|
30 |
+
from app.src.brushedit_all_in_one_pipeline import BrushEdit_Pipeline
|
31 |
+
from app.utils.utils import load_grounding_dino_model
|
32 |
+
|
33 |
+
from app.src.vlm_template import vlms_template
|
34 |
+
from app.src.base_model_template import base_models_template
|
35 |
+
from app.src.aspect_ratio_template import aspect_ratios
|
36 |
+
|
37 |
+
from openai import OpenAI
|
38 |
+
# base_openai_url = ""
|
39 |
+
|
40 |
+
#### Description ####
|
41 |
+
logo = r"""
|
42 |
+
<center><img src='./assets/logo_brushedit.png' alt='BrushEdit logo' style="width:80px; margin-bottom:10px"></center>
|
43 |
+
"""
|
44 |
+
head = r"""
|
45 |
+
<div style="text-align: center;">
|
46 |
+
<h1> BrushEdit: All-In-One Image Inpainting and Editing</h1>
|
47 |
+
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
|
48 |
+
<a href='https://liyaowei-stu.github.io/project/BrushEdit/'><img src='https://img.shields.io/badge/Project_Page-BrushEdit-green' alt='Project Page'></a>
|
49 |
+
<a href='https://arxiv.org/abs/2412.10316'><img src='https://img.shields.io/badge/Paper-Arxiv-blue'></a>
|
50 |
+
<a href='https://github.com/TencentARC/BrushEdit'><img src='https://img.shields.io/badge/Code-Github-orange'></a>
|
51 |
+
|
52 |
+
</div>
|
53 |
+
</br>
|
54 |
+
</div>
|
55 |
+
"""
|
56 |
+
descriptions = r"""
|
57 |
+
Official Gradio Demo for <a href='https://tencentarc.github.io/BrushNet/'><b>BrushEdit: All-In-One Image Inpainting and Editing</b></a><br>
|
58 |
+
🧙 BrushEdit enables precise, user-friendly instruction-based image editing via an inpainting model.<br>
|
59 |
+
"""
|
60 |
+
|
61 |
+
instructions = r"""
|
62 |
+
Currently, we support two modes: <b>fully automated command editing</b> and <b>interactive command editing</b>.
|
63 |
+
|
64 |
+
🛠️ <b>Fully automated instruction-based editing</b>:
|
65 |
+
<ul>
|
66 |
+
<li> ⭐️ <b>1.Choose Image: </b> Upload <img src="https://github.com/user-attachments/assets/f2dca1e6-31f9-4716-ae84-907f24415bac" alt="upload" style="display:inline; height:1em; vertical-align:middle;"> or select <img src="https://github.com/user-attachments/assets/de808f7d-c74a-44c7-9cbf-f0dbfc2c1abf" alt="example" style="display:inline; height:1em; vertical-align:middle;"> one image from Example. </li>
|
67 |
+
<li> ⭐️ <b>2.Input ⌨️ Instructions: </b> Input the instructions (supports addition, deletion, and modification), e.g. remove xxx .</li>
|
68 |
+
<li> ⭐️ <b>3.Run: </b> Click the <b>💫 Run</b> button to automatically edit the image.</li>
|
69 |
+
</ul>
|
70 |
+
|
71 |
+
🛠️ <b>Interactive instruction-based editing</b>:
|
72 |
+
<ul>
|
73 |
+
<li> ⭐️ <b>1.Choose Image: </b> Upload <img src="https://github.com/user-attachments/assets/f2dca1e6-31f9-4716-ae84-907f24415bac" alt="upload" style="display:inline; height:1em; vertical-align:middle;"> or select <img src="https://github.com/user-attachments/assets/de808f7d-c74a-44c7-9cbf-f0dbfc2c1abf" alt="example" style="display:inline; height:1em; vertical-align:middle;"> one image from Example. </li>
|
74 |
+
<li> ⭐️ <b>2.Fine Brushing: </b> Use a brush <img src="https://github.com/user-attachments/assets/c466c5cc-ac8f-4b4a-9bc5-04c4737fe1ef" alt="brush" style="display:inline; height:1em; vertical-align:middle;"> to outline the area you want to edit. You can also use the eraser <img src="https://github.com/user-attachments/assets/b6370369-b080-4550-b0d0-830ff22d9068" alt="eraser" style="display:inline; height:1em; vertical-align:middle;"> to restore it. </li>
|
75 |
+
<li> ⭐️ <b>3.Input ⌨️ Instructions: </b> Input the instructions. </li>
|
76 |
+
<li> ⭐️ <b>4.Run: </b> Click the <b>💫 Run</b> button to automatically edit the image. </li>
|
77 |
+
</ul>
|
78 |
+
|
79 |
+
<b> We strongly recommend using GPT-4o for reasoning. </b> After selecting GPT4-o as the VLM model, enter the API key and click the Submit and Verify button. If the output is "success", you can use GPT4-o normally. As a second choice, we recommend the Qwen2VL model.
|
80 |
+
|
81 |
+
<b> We recommend zooming out in your browser for a better viewing range and experience. </b>
|
82 |
+
|
83 |
+
<b> For more detailed feature descriptions, see the bottom. </b>
|
84 |
+
|
85 |
+
☕️ Have fun! 🎄 Wishing you a merry Christmas!
|
86 |
+
"""
|
87 |
+
|
88 |
+
tips = r"""
|
89 |
+
💡 <b>Some Tips</b>:
|
90 |
+
<ul>
|
91 |
+
<li> 🤠 After inputting the instructions, you can click the <b>Generate Mask</b> button. The mask generated by the VLM will be displayed in the preview panel on the right side. </li>
|
92 |
+
<li> 🤠 After generating the mask, or after drawing one with the brush, you can perform operations such as <b>randomization</b>, <b>dilation</b>, <b>erosion</b>, and <b>movement</b>. </li>
|
93 |
+
<li> 🤠 After inputting the instructions, you can click the <b>Generate Target Prompt</b> button. The target prompt will be displayed in the text box, and you can modify it according to your ideas. </li>
|
94 |
+
</ul>
|
95 |
+
|
96 |
+
💡 <b>Detailed Features</b>:
|
97 |
+
<ul>
|
98 |
+
<li> 🎨 <b>Aspect Ratio</b>: Select the aspect ratio of the image. To prevent OOM, 1024px is the maximum resolution.</li>
|
99 |
+
<li> 🎨 <b>VLM Model</b>: Select the VLM model. We use preloaded models to save time. To use other VLM models, download them and uncomment the relevant lines in vlm_template.py from our GitHub repo. </li>
|
100 |
+
<li> 🎨 <b>Generate Mask</b>: According to the input instructions, generate a mask for the area that may need to be edited. </li>
|
101 |
+
<li> 🎨 <b>Square/Circle Mask</b>: Based on the existing mask, generate square or circular masks. (A coarser mask leaves more room for creative edits.) </li>
|
102 |
+
<li> 🎨 <b>Invert Mask</b>: Invert the mask to generate a new mask. </li>
|
103 |
+
<li> 🎨 <b>Dilation/Erosion Mask</b>: Expand or shrink the mask to include or exclude more areas. </li>
|
104 |
+
<li> 🎨 <b>Move Mask</b>: Move the mask to a new position. </li>
|
105 |
+
<li> 🎨 <b>Generate Target Prompt</b>: Generate a target prompt based on the input instructions. </li>
|
106 |
+
<li> 🎨 <b>Target Prompt</b>: Description for masking area, manual input or modification can be made when the content generated by VLM does not meet expectations. </li>
|
107 |
+
<li> 🎨 <b>Blending</b>: Blend BrushNet's output with the original input to preserve the original image details in the unedited areas. (Turning it off is better when removing objects.) </li>
|
108 |
+
<li> 🎨 <b>Control strength</b>: The intensity of editing and inpainting. </li>
|
109 |
+
</ul>
|
110 |
+
|
111 |
+
💡 <b>Advanced Features</b>:
|
112 |
+
<ul>
|
113 |
+
<li> 🎨 <b>Base Model</b>: We use preloaded models to save time. To use other base models, download them and uncomment the relevant lines in base_model_template.py from our GitHub repo. </li>
|
114 |
+
<li> 🎨 <b>Blending</b>: Blend BrushNet's output with the original input to preserve the original image details in the unedited areas. (Turning it off is better when removing objects.) </li>
|
115 |
+
<li> 🎨 <b>Control strength</b>: The intensity of editing and inpainting. </li>
|
116 |
+
<li> 🎨 <b>Num samples</b>: The number of samples to generate. </li>
|
117 |
+
<li> 🎨 <b>Negative prompt</b>: The negative prompt for the classifier-free guidance. </li>
|
118 |
+
<li> 🎨 <b>Guidance scale</b>: The guidance scale for the classifier-free guidance. </li>
|
119 |
+
</ul>
|
120 |
+
|
121 |
+
|
122 |
+
"""
|
123 |
+
|
124 |
+
|
125 |
+
|
126 |
+
citation = r"""
|
127 |
+
If BrushEdit is helpful, please help to ⭐ the <a href='https://github.com/TencentARC/BrushEdit' target='_blank'>Github Repo</a>. Thanks!
|
128 |
+
[](https://github.com/TencentARC/BrushEdit)
|
129 |
+
---
|
130 |
+
📝 **Citation**
|
131 |
+
<br>
|
132 |
+
If our work is useful for your research, please consider citing:
|
133 |
+
```bibtex
|
134 |
+
@misc{li2024brushedit,
|
135 |
+
title={BrushEdit: All-In-One Image Inpainting and Editing},
|
136 |
+
author={Yaowei Li and Yuxuan Bian and Xuan Ju and Zhaoyang Zhang and Junhao Zhuang and Ying Shan and Yuexian Zou and Qiang Xu},
|
137 |
+
year={2024},
|
138 |
+
eprint={2412.10316},
|
139 |
+
archivePrefix={arXiv},
|
140 |
+
primaryClass={cs.CV}
|
141 |
+
}
|
142 |
+
```
|
143 |
+
📧 **Contact**
|
144 |
+
<br>
|
145 |
+
If you have any questions, please feel free to reach out to me at <b>liyaowei@gmail.com</b>.
|
146 |
+
"""
|
147 |
+
|
148 |
+
# - - - - - examples - - - - - #
|
149 |
+
EXAMPLES = [
|
150 |
+
|
151 |
+
[
|
152 |
+
Image.open("./assets/frog/frog.jpeg").convert("RGBA"),
|
153 |
+
"add a magic hat on frog head.",
|
154 |
+
642087011,
|
155 |
+
"frog",
|
156 |
+
"frog",
|
157 |
+
True,
|
158 |
+
False,
|
159 |
+
"GPT4-o (Highly Recommended)"
|
160 |
+
],
|
161 |
+
[
|
162 |
+
Image.open("./assets/chinese_girl/chinese_girl.png").convert("RGBA"),
|
163 |
+
"replace the background to ancient China.",
|
164 |
+
648464818,
|
165 |
+
"chinese_girl",
|
166 |
+
"chinese_girl",
|
167 |
+
True,
|
168 |
+
False,
|
169 |
+
"GPT4-o (Highly Recommended)"
|
170 |
+
],
|
171 |
+
[
|
172 |
+
Image.open("./assets/angel_christmas/angel_christmas.png").convert("RGBA"),
|
173 |
+
"remove the deer.",
|
174 |
+
648464818,
|
175 |
+
"angel_christmas",
|
176 |
+
"angel_christmas",
|
177 |
+
False,
|
178 |
+
False,
|
179 |
+
"GPT4-o (Highly Recommended)"
|
180 |
+
],
|
181 |
+
[
|
182 |
+
Image.open("./assets/sunflower_girl/sunflower_girl.png").convert("RGBA"),
|
183 |
+
"add a wreath on head.",
|
184 |
+
648464818,
|
185 |
+
"sunflower_girl",
|
186 |
+
"sunflower_girl",
|
187 |
+
True,
|
188 |
+
False,
|
189 |
+
"GPT4-o (Highly Recommended)"
|
190 |
+
],
|
191 |
+
[
|
192 |
+
Image.open("./assets/girl_on_sun/girl_on_sun.png").convert("RGBA"),
|
193 |
+
"add a butterfly fairy.",
|
194 |
+
648464818,
|
195 |
+
"girl_on_sun",
|
196 |
+
"girl_on_sun",
|
197 |
+
True,
|
198 |
+
False,
|
199 |
+
"GPT4-o (Highly Recommended)"
|
200 |
+
],
|
201 |
+
[
|
202 |
+
Image.open("./assets/spider_man_rm/spider_man.png").convert("RGBA"),
|
203 |
+
"remove the christmas hat.",
|
204 |
+
642087011,
|
205 |
+
"spider_man_rm",
|
206 |
+
"spider_man_rm",
|
207 |
+
False,
|
208 |
+
False,
|
209 |
+
"GPT4-o (Highly Recommended)"
|
210 |
+
],
|
211 |
+
[
|
212 |
+
Image.open("./assets/anime_flower/anime_flower.png").convert("RGBA"),
|
213 |
+
"remove the flower.",
|
214 |
+
642087011,
|
215 |
+
"anime_flower",
|
216 |
+
"anime_flower",
|
217 |
+
False,
|
218 |
+
False,
|
219 |
+
"GPT4-o (Highly Recommended)"
|
220 |
+
],
|
221 |
+
[
|
222 |
+
Image.open("./assets/chenduling/chengduling.jpg").convert("RGBA"),
|
223 |
+
"replace the clothes to a delicated floral skirt.",
|
224 |
+
648464818,
|
225 |
+
"chenduling",
|
226 |
+
"chenduling",
|
227 |
+
True,
|
228 |
+
False,
|
229 |
+
"GPT4-o (Highly Recommended)"
|
230 |
+
],
|
231 |
+
[
|
232 |
+
Image.open("./assets/hedgehog_rp_bg/hedgehog.png").convert("RGBA"),
|
233 |
+
"make the hedgehog in Italy.",
|
234 |
+
648464818,
|
235 |
+
"hedgehog_rp_bg",
|
236 |
+
"hedgehog_rp_bg",
|
237 |
+
True,
|
238 |
+
False,
|
239 |
+
"GPT4-o (Highly Recommended)"
|
240 |
+
],
|
241 |
+
|
242 |
+
]
|
243 |
+
|
244 |
+
INPUT_IMAGE_PATH = {
|
245 |
+
"frog": "./assets/frog/frog.jpeg",
|
246 |
+
"chinese_girl": "./assets/chinese_girl/chinese_girl.png",
|
247 |
+
"angel_christmas": "./assets/angel_christmas/angel_christmas.png",
|
248 |
+
"sunflower_girl": "./assets/sunflower_girl/sunflower_girl.png",
|
249 |
+
"girl_on_sun": "./assets/girl_on_sun/girl_on_sun.png",
|
250 |
+
"spider_man_rm": "./assets/spider_man_rm/spider_man.png",
|
251 |
+
"anime_flower": "./assets/anime_flower/anime_flower.png",
|
252 |
+
"chenduling": "./assets/chenduling/chengduling.jpg",
|
253 |
+
"hedgehog_rp_bg": "./assets/hedgehog_rp_bg/hedgehog.png",
|
254 |
+
}
|
255 |
+
MASK_IMAGE_PATH = {
|
256 |
+
"frog": "./assets/frog/mask_f7b350de-6f2c-49e3-b535-995c486d78e7.png",
|
257 |
+
"chinese_girl": "./assets/chinese_girl/mask_54759648-0989-48e0-bc82-f20e28b5ec29.png",
|
258 |
+
"angel_christmas": "./assets/angel_christmas/mask_f15d9b45-c978-4e3d-9f5f-251e308560c3.png",
|
259 |
+
"sunflower_girl": "./assets/sunflower_girl/mask_99cc50b4-7dc4-4de5-8748-ec10772f0317.png",
|
260 |
+
"girl_on_sun": "./assets/girl_on_sun/mask_264eac8b-8b65-479c-9755-020a60880c37.png",
|
261 |
+
"spider_man_rm": "./assets/spider_man_rm/mask_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png",
|
262 |
+
"anime_flower": "./assets/anime_flower/mask_37553172-9b38-4727-bf2e-37d7e2b93461.png",
|
263 |
+
"chenduling": "./assets/chenduling/mask_68e3ff6f-da07-4b37-91df-13d6eed7b997.png",
|
264 |
+
"hedgehog_rp_bg": "./assets/hedgehog_rp_bg/mask_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png",
|
265 |
+
}
|
266 |
+
MASKED_IMAGE_PATH = {
|
267 |
+
"frog": "./assets/frog/masked_image_f7b350de-6f2c-49e3-b535-995c486d78e7.png",
|
268 |
+
"chinese_girl": "./assets/chinese_girl/masked_image_54759648-0989-48e0-bc82-f20e28b5ec29.png",
|
269 |
+
"angel_christmas": "./assets/angel_christmas/masked_image_f15d9b45-c978-4e3d-9f5f-251e308560c3.png",
|
270 |
+
"sunflower_girl": "./assets/sunflower_girl/masked_image_99cc50b4-7dc4-4de5-8748-ec10772f0317.png",
|
271 |
+
"girl_on_sun": "./assets/girl_on_sun/masked_image_264eac8b-8b65-479c-9755-020a60880c37.png",
|
272 |
+
"spider_man_rm": "./assets/spider_man_rm/masked_image_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png",
|
273 |
+
"anime_flower": "./assets/anime_flower/masked_image_37553172-9b38-4727-bf2e-37d7e2b93461.png",
|
274 |
+
"chenduling": "./assets/chenduling/masked_image_68e3ff6f-da07-4b37-91df-13d6eed7b997.png",
|
275 |
+
"hedgehog_rp_bg": "./assets/hedgehog_rp_bg/masked_image_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png",
|
276 |
+
}
|
277 |
+
OUTPUT_IMAGE_PATH = {
|
278 |
+
"frog": "./assets/frog/image_edit_f7b350de-6f2c-49e3-b535-995c486d78e7_1.png",
|
279 |
+
"chinese_girl": "./assets/chinese_girl/image_edit_54759648-0989-48e0-bc82-f20e28b5ec29_1.png",
|
280 |
+
"angel_christmas": "./assets/angel_christmas/image_edit_f15d9b45-c978-4e3d-9f5f-251e308560c3_0.png",
|
281 |
+
"sunflower_girl": "./assets/sunflower_girl/image_edit_99cc50b4-7dc4-4de5-8748-ec10772f0317_3.png",
|
282 |
+
"girl_on_sun": "./assets/girl_on_sun/image_edit_264eac8b-8b65-479c-9755-020a60880c37_0.png",
|
283 |
+
"spider_man_rm": "./assets/spider_man_rm/image_edit_a5d410e6-8e8d-432f-8144-defbc3e1eae9_0.png",
|
284 |
+
"anime_flower": "./assets/anime_flower/image_edit_37553172-9b38-4727-bf2e-37d7e2b93461_2.png",
|
285 |
+
"chenduling": "./assets/chenduling/image_edit_68e3ff6f-da07-4b37-91df-13d6eed7b997_0.png",
|
286 |
+
"hedgehog_rp_bg": "./assets/hedgehog_rp_bg/image_edit_db7f8bf8-8349-46d3-b14e-43d67fbe25d3_3.png",
|
287 |
+
}
|
288 |
+
|
289 |
+
# os.environ['GRADIO_TEMP_DIR'] = 'gradio_temp_dir'
|
290 |
+
# os.makedirs('gradio_temp_dir', exist_ok=True)
|
291 |
+
|
292 |
+
VLM_MODEL_NAMES = list(vlms_template.keys())
|
293 |
+
DEFAULT_VLM_MODEL_NAME = "Qwen2-VL-7B-Instruct (Default)"
|
294 |
+
BASE_MODELS = list(base_models_template.keys())
|
295 |
+
DEFAULT_BASE_MODEL = "realisticVision (Default)"
|
296 |
+
|
297 |
+
ASPECT_RATIO_LABELS = list(aspect_ratios)
|
298 |
+
DEFAULT_ASPECT_RATIO = ASPECT_RATIO_LABELS[0]
|
299 |
+
|
300 |
+
device = "cuda"
|
301 |
+
torch_dtype = torch.bfloat16
|
302 |
+
|
303 |
+
## init device
|
304 |
+
# try:
|
305 |
+
# if torch.cuda.is_available():
|
306 |
+
# device = "cuda"
|
307 |
+
# print("device = cuda")
|
308 |
+
# elif sys.platform == "darwin" and torch.backends.mps.is_available():
|
309 |
+
# device = "mps"
|
310 |
+
# print("device = mps")
|
311 |
+
# else:
|
312 |
+
# device = "cpu"
|
313 |
+
# print("device = cpu")
|
314 |
+
# except:
|
315 |
+
# device = "cpu"
|
316 |
+
|
317 |
+
|
318 |
+
|
319 |
+
# download hf models
|
320 |
+
BrushEdit_path = "models/"
|
321 |
+
if not os.path.exists(BrushEdit_path):
|
322 |
+
BrushEdit_path = snapshot_download(
|
323 |
+
repo_id="TencentARC/BrushEdit",
|
324 |
+
local_dir=BrushEdit_path,
|
325 |
+
token=os.getenv("HF_TOKEN"),
|
326 |
+
)
|
327 |
+
|
328 |
+
## init default VLM
|
329 |
+
vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[DEFAULT_VLM_MODEL_NAME]
|
330 |
+
if vlm_processor != "" and vlm_model != "":
|
331 |
+
vlm_model.to(device)
|
332 |
+
else:
|
333 |
+
raise gr.Error("Please Download default VLM model "+ DEFAULT_VLM_MODEL_NAME +" first.")
|
334 |
+
|
335 |
+
|
336 |
+
## init base model
|
337 |
+
base_model_path = os.path.join(BrushEdit_path, "base_model/realisticVisionV60B1_v51VAE")
|
338 |
+
brushnet_path = os.path.join(BrushEdit_path, "brushnetX")
|
339 |
+
sam_path = os.path.join(BrushEdit_path, "sam/sam_vit_h_4b8939.pth")
|
340 |
+
groundingdino_path = os.path.join(BrushEdit_path, "grounding_dino/groundingdino_swint_ogc.pth")
|
341 |
+
|
342 |
+
|
343 |
+
# input brushnetX ckpt path
|
344 |
+
brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch_dtype)
|
345 |
+
pipe = StableDiffusionBrushNetPipeline.from_pretrained(
|
346 |
+
base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
|
347 |
+
)
|
348 |
+
# speed up diffusion process with faster scheduler and memory optimization
|
349 |
+
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
|
350 |
+
# uncomment the following line if xformers is installed and you are not using Torch 2.0.
|
351 |
+
# pipe.enable_xformers_memory_efficient_attention()
|
352 |
+
pipe.enable_model_cpu_offload()
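# Note (assumption, kept as a comment only): enable_model_cpu_offload keeps sub-modules on the CPU and moves
# each one to the GPU only while it is being used, trading some speed for lower VRAM. If memory were not a
# concern, a plain `pipe.to(device)` would be the usual alternative.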
|
353 |
+
|
354 |
+
|
355 |
+
## init SAM
|
356 |
+
sam = build_sam(checkpoint=sam_path)
|
357 |
+
sam.to(device=device)
|
358 |
+
sam_predictor = SamPredictor(sam)
|
359 |
+
sam_automask_generator = SamAutomaticMaskGenerator(sam)
|
360 |
+
|
361 |
+
## init groundingdino_model
|
362 |
+
config_file = 'app/utils/GroundingDINO_SwinT_OGC.py'
|
363 |
+
groundingdino_model = load_grounding_dino_model(config_file, groundingdino_path, device=device)
|
364 |
+
|
365 |
+
## Ordinary function
|
366 |
+
def crop_and_resize(image: Image.Image,
|
367 |
+
target_width: int,
|
368 |
+
target_height: int) -> Image.Image:
|
369 |
+
"""
|
370 |
+
Crops and resizes an image while preserving the aspect ratio.
|
371 |
+
|
372 |
+
Args:
|
373 |
+
image (Image.Image): Input PIL image to be cropped and resized.
|
374 |
+
target_width (int): Target width of the output image.
|
375 |
+
target_height (int): Target height of the output image.
|
376 |
+
|
377 |
+
Returns:
|
378 |
+
Image.Image: Cropped and resized image.
|
379 |
+
"""
|
380 |
+
# Original dimensions
|
381 |
+
original_width, original_height = image.size
|
382 |
+
original_aspect = original_width / original_height
|
383 |
+
target_aspect = target_width / target_height
|
384 |
+
|
385 |
+
# Calculate crop box to maintain aspect ratio
|
386 |
+
if original_aspect > target_aspect:
|
387 |
+
# Crop horizontally
|
388 |
+
new_width = int(original_height * target_aspect)
|
389 |
+
new_height = original_height
|
390 |
+
left = (original_width - new_width) / 2
|
391 |
+
top = 0
|
392 |
+
right = left + new_width
|
393 |
+
bottom = original_height
|
394 |
+
else:
|
395 |
+
# Crop vertically
|
396 |
+
new_width = original_width
|
397 |
+
new_height = int(original_width / target_aspect)
|
398 |
+
left = 0
|
399 |
+
top = (original_height - new_height) / 2
|
400 |
+
right = original_width
|
401 |
+
bottom = top + new_height
|
402 |
+
|
403 |
+
# Crop and resize
|
404 |
+
cropped_image = image.crop((left, top, right, bottom))
|
405 |
+
resized_image = cropped_image.resize((target_width, target_height), Image.NEAREST)
|
406 |
+
return resized_image
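# Usage sketch for crop_and_resize (illustrative only; "photo.jpg" is a hypothetical local file):
#   img = Image.open("photo.jpg")              # e.g. 1920x1080
#   out = crop_and_resize(img, 640, 640)       # center-crops to 1080x1080, then resizes to 640x640
#   assert out.size == (640, 640)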
|
407 |
+
|
408 |
+
|
409 |
+
## Ordinary function
|
410 |
+
def resize(image: Image.Image,
|
411 |
+
target_width: int,
|
412 |
+
target_height: int) -> Image.Image:
|
413 |
+
"""
|
414 |
+
Resizes an image to the target dimensions (the aspect ratio is not preserved).
|
415 |
+
|
416 |
+
Args:
|
417 |
+
image (Image.Image): Input PIL image to be cropped and resized.
|
418 |
+
target_width (int): Target width of the output image.
|
419 |
+
target_height (int): Target height of the output image.
|
420 |
+
|
421 |
+
Returns:
|
422 |
+
Image.Image: Resized image.
|
423 |
+
"""
|
424 |
+
# Resize directly to the target size (no cropping)
|
425 |
+
resized_image = image.resize((target_width, target_height), Image.NEAREST)
|
426 |
+
return resized_image
|
427 |
+
|
428 |
+
|
429 |
+
def move_mask_func(mask, direction, units):
|
430 |
+
binary_mask = mask.squeeze()>0
|
431 |
+
rows, cols = binary_mask.shape
|
432 |
+
moved_mask = np.zeros_like(binary_mask, dtype=bool)
|
433 |
+
|
434 |
+
if direction == 'down':
|
435 |
+
# move down
|
436 |
+
moved_mask[max(0, units):, :] = binary_mask[:rows - units, :]
|
437 |
+
|
438 |
+
elif direction == 'up':
|
439 |
+
# move up
|
440 |
+
moved_mask[:rows - units, :] = binary_mask[units:, :]
|
441 |
+
|
442 |
+
elif direction == 'right':
|
443 |
+
# move right
|
444 |
+
moved_mask[:, max(0, units):] = binary_mask[:, :cols - units]
|
445 |
+
|
446 |
+
elif direction == 'left':
|
447 |
+
# move left
|
448 |
+
moved_mask[:, :cols - units] = binary_mask[:, units:]
|
449 |
+
|
450 |
+
return moved_mask
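# Behaviour sketch for move_mask_func (illustrative only):
#   mask = np.zeros((4, 4), dtype=np.uint8); mask[1, 1] = 255
#   moved = move_mask_func(mask, 'right', 2)   # boolean array; the single True pixel is now at (1, 3)
# The returned array is boolean; callers below rescale it back to uint8 when needed.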
|
451 |
+
|
452 |
+
|
453 |
+
def random_mask_func(mask, dilation_type='square', dilation_size=20):
|
454 |
+
# Randomly select the size of dilation
|
455 |
+
binary_mask = mask.squeeze()>0
|
456 |
+
|
457 |
+
if dilation_type == 'square_dilation':
|
458 |
+
structure = np.ones((dilation_size, dilation_size), dtype=bool)
|
459 |
+
dilated_mask = binary_dilation(binary_mask, structure=structure)
|
460 |
+
elif dilation_type == 'square_erosion':
|
461 |
+
structure = np.ones((dilation_size, dilation_size), dtype=bool)
|
462 |
+
dilated_mask = binary_erosion(binary_mask, structure=structure)
|
463 |
+
elif dilation_type == 'bounding_box':
|
464 |
+
# find the most left top and left bottom point
|
465 |
+
rows, cols = np.where(binary_mask)
|
466 |
+
if len(rows) == 0 or len(cols) == 0:
|
467 |
+
return mask # return original mask if no valid points
|
468 |
+
|
469 |
+
min_row = np.min(rows)
|
470 |
+
max_row = np.max(rows)
|
471 |
+
min_col = np.min(cols)
|
472 |
+
max_col = np.max(cols)
|
473 |
+
|
474 |
+
# create a bounding box
|
475 |
+
dilated_mask = np.zeros_like(binary_mask, dtype=bool)
|
476 |
+
dilated_mask[min_row:max_row + 1, min_col:max_col + 1] = True
|
477 |
+
|
478 |
+
elif dilation_type == 'bounding_ellipse':
|
479 |
+
# find the most left top and left bottom point
|
480 |
+
rows, cols = np.where(binary_mask)
|
481 |
+
if len(rows) == 0 or len(cols) == 0:
|
482 |
+
return mask # return original mask if no valid points
|
483 |
+
|
484 |
+
min_row = np.min(rows)
|
485 |
+
max_row = np.max(rows)
|
486 |
+
min_col = np.min(cols)
|
487 |
+
max_col = np.max(cols)
|
488 |
+
|
489 |
+
# calculate the center and axis length of the ellipse
|
490 |
+
center = ((min_col + max_col) // 2, (min_row + max_row) // 2)
|
491 |
+
a = (max_col - min_col) // 2 # half long axis
|
492 |
+
b = (max_row - min_row) // 2 # half short axis
|
493 |
+
|
494 |
+
# create a bounding ellipse
|
495 |
+
y, x = np.ogrid[:mask.shape[0], :mask.shape[1]]
|
496 |
+
ellipse_mask = ((x - center[0])**2 / a**2 + (y - center[1])**2 / b**2) <= 1
|
497 |
+
dilated_mask = np.zeros_like(binary_mask, dtype=bool)
|
498 |
+
dilated_mask[ellipse_mask] = True
|
499 |
+
else:
|
500 |
+
ValueError("dilation_type must be 'square' or 'ellipse'")
|
501 |
+
|
502 |
+
# convert the boolean mask to a single-channel uint8 mask in [0, 255]
|
503 |
+
dilated_mask = np.uint8(dilated_mask[:,:,np.newaxis]) * 255
|
504 |
+
return dilated_mask
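# Usage sketch for random_mask_func (illustrative only):
#   mask = np.zeros((64, 64, 1), dtype=np.uint8); mask[20:40, 20:40] = 255
#   boxed = random_mask_func(mask, 'bounding_box')           # tight uint8 box around the region, values in {0, 255}
#   grown = random_mask_func(mask, 'square_dilation', 10)    # region dilated by a 10x10 structuring element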
|
505 |
+
|
506 |
+
|
507 |
+
## Gradio component function
|
508 |
+
def update_vlm_model(vlm_name):
|
509 |
+
global vlm_model, vlm_processor
|
510 |
+
if vlm_model is not None:
|
511 |
+
del vlm_model
|
512 |
+
torch.cuda.empty_cache()
|
513 |
+
|
514 |
+
vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[vlm_name]
|
515 |
+
|
516 |
+
## We recommend using preloaded models; otherwise it will take a long time to download the model. You can edit the code via vlm_template.py
|
517 |
+
if vlm_type == "llava-next":
|
518 |
+
if vlm_processor != "" and vlm_model != "":
|
519 |
+
vlm_model.to(device)
|
520 |
+
return vlm_model_dropdown
|
521 |
+
else:
|
522 |
+
if os.path.exists(vlm_local_path):
|
523 |
+
vlm_processor = LlavaNextProcessor.from_pretrained(vlm_local_path)
|
524 |
+
vlm_model = LlavaNextForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto")
|
525 |
+
else:
|
526 |
+
if vlm_name == "llava-v1.6-mistral-7b-hf (Preload)":
|
527 |
+
vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
|
528 |
+
vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", torch_dtype="auto", device_map="auto")
|
529 |
+
elif vlm_name == "llama3-llava-next-8b-hf (Preload)":
|
530 |
+
vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llama3-llava-next-8b-hf")
|
531 |
+
vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llama3-llava-next-8b-hf", torch_dtype="auto", device_map="auto")
|
532 |
+
elif vlm_name == "llava-v1.6-vicuna-13b-hf (Preload)":
|
533 |
+
vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf")
|
534 |
+
vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf", torch_dtype="auto", device_map="auto")
|
535 |
+
elif vlm_name == "llava-v1.6-34b-hf (Preload)":
|
536 |
+
vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-34b-hf")
|
537 |
+
vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-34b-hf", torch_dtype="auto", device_map="auto")
|
538 |
+
elif vlm_name == "llava-next-72b-hf (Preload)":
|
539 |
+
vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-next-72b-hf")
|
540 |
+
vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-next-72b-hf", torch_dtype="auto", device_map="auto")
|
541 |
+
elif vlm_type == "qwen2-vl":
|
542 |
+
if vlm_processor != "" and vlm_model != "":
|
543 |
+
vlm_model.to(device)
|
544 |
+
return vlm_model_dropdown
|
545 |
+
else:
|
546 |
+
if os.path.exists(vlm_local_path):
|
547 |
+
vlm_processor = Qwen2VLProcessor.from_pretrained(vlm_local_path)
|
548 |
+
vlm_model = Qwen2VLForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto")
|
549 |
+
else:
|
550 |
+
if vlm_name == "qwen2-vl-2b-instruct (Preload)":
|
551 |
+
vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
|
552 |
+
vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", torch_dtype="auto", device_map="auto")
|
553 |
+
elif vlm_name == "qwen2-vl-7b-instruct (Preload)":
|
554 |
+
vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
|
555 |
+
vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto")
|
556 |
+
elif vlm_name == "qwen2-vl-72b-instruct (Preload)":
|
557 |
+
vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-72B-Instruct")
|
558 |
+
vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-72B-Instruct", torch_dtype="auto", device_map="auto")
|
559 |
+
elif vlm_type == "openai":
|
560 |
+
pass
|
561 |
+
return "success"
|
562 |
+
|
563 |
+
|
564 |
+
def update_base_model(base_model_name):
|
565 |
+
global pipe
|
566 |
+
## We recommend using preloaded models; otherwise it will take a long time to download the model. You can edit the code via base_model_template.py
|
567 |
+
if pipe is not None:
|
568 |
+
del pipe
|
569 |
+
torch.cuda.empty_cache()
|
570 |
+
base_model_path, pipe = base_models_template[base_model_name]
|
571 |
+
if pipe != "":
|
572 |
+
pipe.to(device)
|
573 |
+
else:
|
574 |
+
if os.path.exists(base_model_path):
|
575 |
+
pipe = StableDiffusionBrushNetPipeline.from_pretrained(
|
576 |
+
base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
|
577 |
+
)
|
578 |
+
# pipe.enable_xformers_memory_efficient_attention()
|
579 |
+
pipe.enable_model_cpu_offload()
|
580 |
+
else:
|
581 |
+
raise gr.Error(f"The base model {base_model_name} does not exist")
|
582 |
+
return "success"
|
583 |
+
|
584 |
+
|
585 |
+
def submit_GPT4o_KEY(GPT4o_KEY):
|
586 |
+
global vlm_model, vlm_processor
|
587 |
+
if vlm_model is not None:
|
588 |
+
del vlm_model
|
589 |
+
torch.cuda.empty_cache()
|
590 |
+
try:
|
591 |
+
vlm_model = OpenAI(api_key=GPT4o_KEY)
|
592 |
+
vlm_processor = ""
|
593 |
+
response = vlm_model.chat.completions.create(
|
594 |
+
model="gpt-4o-2024-08-06",
|
595 |
+
messages=[
|
596 |
+
{"role": "system", "content": "You are a helpful assistant."},
|
597 |
+
{"role": "user", "content": "Say this is a test"}
|
598 |
+
]
|
599 |
+
)
|
600 |
+
response_str = response.choices[0].message.content
|
601 |
+
|
602 |
+
return "Success, " + response_str, "GPT4-o (Highly Recommended)"
|
603 |
+
except Exception as e:
|
604 |
+
return "Invalid GPT4o API Key", "GPT4-o (Highly Recommended)"
|
605 |
+
|
606 |
+
|
607 |
+
|
608 |
+
def process(input_image,
|
609 |
+
original_image,
|
610 |
+
original_mask,
|
611 |
+
prompt,
|
612 |
+
negative_prompt,
|
613 |
+
control_strength,
|
614 |
+
seed,
|
615 |
+
randomize_seed,
|
616 |
+
guidance_scale,
|
617 |
+
num_inference_steps,
|
618 |
+
num_samples,
|
619 |
+
blending,
|
620 |
+
category,
|
621 |
+
target_prompt,
|
622 |
+
resize_default,
|
623 |
+
aspect_ratio_name,
|
624 |
+
invert_mask_state):
|
625 |
+
if original_image is None:
|
626 |
+
if input_image is None:
|
627 |
+
raise gr.Error('Please upload the input image')
|
628 |
+
else:
|
629 |
+
image_pil = input_image["background"].convert("RGB")
|
630 |
+
original_image = np.array(image_pil)
|
631 |
+
if prompt is None or prompt == "":
|
632 |
+
if target_prompt is None or target_prompt == "":
|
633 |
+
raise gr.Error("Please input your instructions, e.g., remove the xxx")
|
634 |
+
|
635 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
636 |
+
input_mask = np.asarray(alpha_mask)
|
637 |
+
output_w, output_h = aspect_ratios[aspect_ratio_name]
|
638 |
+
if output_w == "" or output_h == "":
|
639 |
+
output_h, output_w = original_image.shape[:2]
|
640 |
+
|
641 |
+
if resize_default:
|
642 |
+
short_side = min(output_w, output_h)
|
643 |
+
scale_ratio = 640 / short_side
|
644 |
+
output_w = int(output_w * scale_ratio)
|
645 |
+
output_h = int(output_h * scale_ratio)
|
646 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
647 |
+
original_image = np.array(original_image)
|
648 |
+
if input_mask is not None:
|
649 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
650 |
+
input_mask = np.array(input_mask)
|
651 |
+
if original_mask is not None:
|
652 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
653 |
+
original_mask = np.array(original_mask)
|
654 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
655 |
+
else:
|
656 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
657 |
+
pass
|
658 |
+
else:
|
659 |
+
if resize_default:
|
660 |
+
short_side = min(output_w, output_h)
|
661 |
+
scale_ratio = 640 / short_side
|
662 |
+
output_w = int(output_w * scale_ratio)
|
663 |
+
output_h = int(output_h * scale_ratio)
|
664 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
665 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
666 |
+
original_image = np.array(original_image)
|
667 |
+
if input_mask is not None:
|
668 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
669 |
+
input_mask = np.array(input_mask)
|
670 |
+
if original_mask is not None:
|
671 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
672 |
+
original_mask = np.array(original_mask)
|
673 |
+
|
674 |
+
if invert_mask_state:
|
675 |
+
original_mask = original_mask
|
676 |
+
else:
|
677 |
+
if input_mask.max() == 0:
|
678 |
+
original_mask = original_mask
|
679 |
+
else:
|
680 |
+
original_mask = input_mask
|
681 |
+
|
682 |
+
|
683 |
+
## inpainting directly if target_prompt is not None
|
684 |
+
if category is not None:
|
685 |
+
pass
|
686 |
+
elif target_prompt is not None and len(target_prompt) >= 1 and original_mask is not None:
|
687 |
+
pass
|
688 |
+
else:
|
689 |
+
try:
|
690 |
+
category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device)
|
691 |
+
except Exception as e:
|
692 |
+
raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
|
693 |
+
|
694 |
+
|
695 |
+
if original_mask is not None:
|
696 |
+
original_mask = np.clip(original_mask, 0, 255).astype(np.uint8)
|
697 |
+
else:
|
698 |
+
try:
|
699 |
+
object_wait_for_edit = vlm_response_object_wait_for_edit(
|
700 |
+
vlm_processor,
|
701 |
+
vlm_model,
|
702 |
+
original_image,
|
703 |
+
category,
|
704 |
+
prompt,
|
705 |
+
device)
|
706 |
+
|
707 |
+
original_mask = vlm_response_mask(vlm_processor,
|
708 |
+
vlm_model,
|
709 |
+
category,
|
710 |
+
original_image,
|
711 |
+
prompt,
|
712 |
+
object_wait_for_edit,
|
713 |
+
sam,
|
714 |
+
sam_predictor,
|
715 |
+
sam_automask_generator,
|
716 |
+
groundingdino_model,
|
717 |
+
device).astype(np.uint8)
|
718 |
+
except Exception as e:
|
719 |
+
raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
|
720 |
+
|
721 |
+
if original_mask.ndim == 2:
|
722 |
+
original_mask = original_mask[:,:,None]
|
723 |
+
|
724 |
+
|
725 |
+
if target_prompt is not None and len(target_prompt) >= 1:
|
726 |
+
prompt_after_apply_instruction = target_prompt
|
727 |
+
|
728 |
+
else:
|
729 |
+
try:
|
730 |
+
prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(
|
731 |
+
vlm_processor,
|
732 |
+
vlm_model,
|
733 |
+
original_image,
|
734 |
+
prompt,
|
735 |
+
device)
|
736 |
+
except Exception as e:
|
737 |
+
raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
|
738 |
+
|
739 |
+
generator = torch.Generator(device).manual_seed(random.randint(0, 2147483647) if randomize_seed else seed)
|
740 |
+
|
741 |
+
|
742 |
+
with torch.autocast(device):
|
743 |
+
image, mask_image, mask_np, init_image_np = BrushEdit_Pipeline(pipe,
|
744 |
+
prompt_after_apply_instruction,
|
745 |
+
original_mask,
|
746 |
+
original_image,
|
747 |
+
generator,
|
748 |
+
num_inference_steps,
|
749 |
+
guidance_scale,
|
750 |
+
control_strength,
|
751 |
+
negative_prompt,
|
752 |
+
num_samples,
|
753 |
+
blending)
|
754 |
+
original_image = np.array(init_image_np)
|
755 |
+
masked_image = original_image * (1 - (mask_np>0))
|
756 |
+
masked_image = masked_image.astype(np.uint8)
|
757 |
+
masked_image = Image.fromarray(masked_image)
|
758 |
+
# Save the images (optional)
|
759 |
+
# import uuid
|
760 |
+
# uuid = str(uuid.uuid4())
|
761 |
+
# image[0].save(f"outputs/image_edit_{uuid}_0.png")
|
762 |
+
# image[1].save(f"outputs/image_edit_{uuid}_1.png")
|
763 |
+
# image[2].save(f"outputs/image_edit_{uuid}_2.png")
|
764 |
+
# image[3].save(f"outputs/image_edit_{uuid}_3.png")
|
765 |
+
# mask_image.save(f"outputs/mask_{uuid}.png")
|
766 |
+
# masked_image.save(f"outputs/masked_image_{uuid}.png")
|
767 |
+
gr.Info(f"Target Prompt: {prompt_after_apply_instruction}", duration=20)
|
768 |
+
return image, [mask_image], [masked_image], prompt, '', False
|
769 |
+
|
770 |
+
|
771 |
+
def generate_target_prompt(input_image,
|
772 |
+
original_image,
|
773 |
+
prompt):
|
774 |
+
# load example image
|
775 |
+
if isinstance(original_image, str):
|
776 |
+
original_image = input_image
|
777 |
+
|
778 |
+
prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(
|
779 |
+
vlm_processor,
|
780 |
+
vlm_model,
|
781 |
+
original_image,
|
782 |
+
prompt,
|
783 |
+
device)
|
784 |
+
return prompt_after_apply_instruction
|
785 |
+
|
786 |
+
|
787 |
+
def process_mask(input_image,
|
788 |
+
original_image,
|
789 |
+
prompt,
|
790 |
+
resize_default,
|
791 |
+
aspect_ratio_name):
|
792 |
+
if original_image is None:
|
793 |
+
raise gr.Error('Please upload the input image')
|
794 |
+
if prompt is None:
|
795 |
+
raise gr.Error("Please input your instructions, e.g., remove the xxx")
|
796 |
+
|
797 |
+
## load mask
|
798 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
799 |
+
input_mask = np.array(alpha_mask)
|
800 |
+
|
801 |
+
# load example image
|
802 |
+
if isinstance(original_image, str):
|
803 |
+
original_image = input_image["background"]
|
804 |
+
|
805 |
+
if input_mask.max() == 0:
|
806 |
+
category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device)
|
807 |
+
|
808 |
+
object_wait_for_edit = vlm_response_object_wait_for_edit(vlm_processor,
|
809 |
+
vlm_model,
|
810 |
+
original_image,
|
811 |
+
category,
|
812 |
+
prompt,
|
813 |
+
device)
|
814 |
+
# original mask: h,w,1 [0, 255]
|
815 |
+
original_mask = vlm_response_mask(
|
816 |
+
vlm_processor,
|
817 |
+
vlm_model,
|
818 |
+
category,
|
819 |
+
original_image,
|
820 |
+
prompt,
|
821 |
+
object_wait_for_edit,
|
822 |
+
sam,
|
823 |
+
sam_predictor,
|
824 |
+
sam_automask_generator,
|
825 |
+
groundingdino_model,
|
826 |
+
device).astype(np.uint8)
|
827 |
+
else:
|
828 |
+
original_mask = input_mask.astype(np.uint8)
|
829 |
+
category = None
|
830 |
+
|
831 |
+
## resize mask if needed
|
832 |
+
output_w, output_h = aspect_ratios[aspect_ratio_name]
|
833 |
+
if output_w == "" or output_h == "":
|
834 |
+
output_h, output_w = original_image.shape[:2]
|
835 |
+
if resize_default:
|
836 |
+
short_side = min(output_w, output_h)
|
837 |
+
scale_ratio = 640 / short_side
|
838 |
+
output_w = int(output_w * scale_ratio)
|
839 |
+
output_h = int(output_h * scale_ratio)
|
840 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
841 |
+
original_image = np.array(original_image)
|
842 |
+
if input_mask is not None:
|
843 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
844 |
+
input_mask = np.array(input_mask)
|
845 |
+
if original_mask is not None:
|
846 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
847 |
+
original_mask = np.array(original_mask)
|
848 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
849 |
+
else:
|
850 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
851 |
+
pass
|
852 |
+
else:
|
853 |
+
if resize_default:
|
854 |
+
short_side = min(output_w, output_h)
|
855 |
+
scale_ratio = 640 / short_side
|
856 |
+
output_w = int(output_w * scale_ratio)
|
857 |
+
output_h = int(output_h * scale_ratio)
|
858 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
859 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
860 |
+
original_image = np.array(original_image)
|
861 |
+
if input_mask is not None:
|
862 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
863 |
+
input_mask = np.array(input_mask)
|
864 |
+
if original_mask is not None:
|
865 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
866 |
+
original_mask = np.array(original_mask)
|
867 |
+
|
868 |
+
|
869 |
+
if original_mask.ndim == 2:
|
870 |
+
original_mask = original_mask[:,:,None]
|
871 |
+
|
872 |
+
mask_image = Image.fromarray(original_mask.squeeze().astype(np.uint8)).convert("RGB")
|
873 |
+
|
874 |
+
masked_image = original_image * (1 - (original_mask>0))
|
875 |
+
masked_image = masked_image.astype(np.uint8)
|
876 |
+
masked_image = Image.fromarray(masked_image)
|
877 |
+
|
878 |
+
return [masked_image], [mask_image], original_mask.astype(np.uint8), category
|
879 |
+
|
880 |
+
|
881 |
+
def process_random_mask(input_image,
|
882 |
+
original_image,
|
883 |
+
original_mask,
|
884 |
+
resize_default,
|
885 |
+
aspect_ratio_name,
|
886 |
+
):
|
887 |
+
|
888 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
889 |
+
input_mask = np.asarray(alpha_mask)
|
890 |
+
|
891 |
+
output_w, output_h = aspect_ratios[aspect_ratio_name]
|
892 |
+
if output_w == "" or output_h == "":
|
893 |
+
output_h, output_w = original_image.shape[:2]
|
894 |
+
if resize_default:
|
895 |
+
short_side = min(output_w, output_h)
|
896 |
+
scale_ratio = 640 / short_side
|
897 |
+
output_w = int(output_w * scale_ratio)
|
898 |
+
output_h = int(output_h * scale_ratio)
|
899 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
900 |
+
original_image = np.array(original_image)
|
901 |
+
if input_mask is not None:
|
902 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
903 |
+
input_mask = np.array(input_mask)
|
904 |
+
if original_mask is not None:
|
905 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
906 |
+
original_mask = np.array(original_mask)
|
907 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
908 |
+
else:
|
909 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
910 |
+
pass
|
911 |
+
else:
|
912 |
+
if resize_default:
|
913 |
+
short_side = min(output_w, output_h)
|
914 |
+
scale_ratio = 640 / short_side
|
915 |
+
output_w = int(output_w * scale_ratio)
|
916 |
+
output_h = int(output_h * scale_ratio)
|
917 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
918 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
919 |
+
original_image = np.array(original_image)
|
920 |
+
if input_mask is not None:
|
921 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
922 |
+
input_mask = np.array(input_mask)
|
923 |
+
if original_mask is not None:
|
924 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
925 |
+
original_mask = np.array(original_mask)
|
926 |
+
|
927 |
+
|
928 |
+
if input_mask.max() == 0:
|
929 |
+
original_mask = original_mask
|
930 |
+
else:
|
931 |
+
original_mask = input_mask
|
932 |
+
|
933 |
+
if original_mask is None:
|
934 |
+
raise gr.Error('Please generate mask first')
|
935 |
+
|
936 |
+
if original_mask.ndim == 2:
|
937 |
+
original_mask = original_mask[:,:,None]
|
938 |
+
|
939 |
+
dilation_type = np.random.choice(['bounding_box', 'bounding_ellipse'])
|
940 |
+
random_mask = random_mask_func(original_mask, dilation_type).squeeze()
|
941 |
+
|
942 |
+
mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
|
943 |
+
|
944 |
+
masked_image = original_image * (1 - (random_mask[:,:,None]>0))
|
945 |
+
masked_image = masked_image.astype(original_image.dtype)
|
946 |
+
masked_image = Image.fromarray(masked_image)
|
947 |
+
|
948 |
+
|
949 |
+
return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
|
950 |
+
|
951 |
+
|
952 |
+
def process_dilation_mask(input_image,
|
953 |
+
original_image,
|
954 |
+
original_mask,
|
955 |
+
resize_default,
|
956 |
+
aspect_ratio_name,
|
957 |
+
dilation_size=20):
|
958 |
+
|
959 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
960 |
+
input_mask = np.asarray(alpha_mask)
|
961 |
+
|
962 |
+
output_w, output_h = aspect_ratios[aspect_ratio_name]
|
963 |
+
if output_w == "" or output_h == "":
|
964 |
+
output_h, output_w = original_image.shape[:2]
|
965 |
+
if resize_default:
|
966 |
+
short_side = min(output_w, output_h)
|
967 |
+
scale_ratio = 640 / short_side
|
968 |
+
output_w = int(output_w * scale_ratio)
|
969 |
+
output_h = int(output_h * scale_ratio)
|
970 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
971 |
+
original_image = np.array(original_image)
|
972 |
+
if input_mask is not None:
|
973 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
974 |
+
input_mask = np.array(input_mask)
|
975 |
+
if original_mask is not None:
|
976 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
977 |
+
original_mask = np.array(original_mask)
|
978 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
979 |
+
else:
|
980 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
981 |
+
pass
|
982 |
+
else:
|
983 |
+
if resize_default:
|
984 |
+
short_side = min(output_w, output_h)
|
985 |
+
scale_ratio = 640 / short_side
|
986 |
+
output_w = int(output_w * scale_ratio)
|
987 |
+
output_h = int(output_h * scale_ratio)
|
988 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
989 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
990 |
+
original_image = np.array(original_image)
|
991 |
+
if input_mask is not None:
|
992 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
993 |
+
input_mask = np.array(input_mask)
|
994 |
+
if original_mask is not None:
|
995 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
996 |
+
original_mask = np.array(original_mask)
|
997 |
+
|
998 |
+
if input_mask.max() == 0:
|
999 |
+
original_mask = original_mask
|
1000 |
+
else:
|
1001 |
+
original_mask = input_mask
|
1002 |
+
|
1003 |
+
if original_mask is None:
|
1004 |
+
raise gr.Error('Please generate mask first')
|
1005 |
+
|
1006 |
+
if original_mask.ndim == 2:
|
1007 |
+
original_mask = original_mask[:,:,None]
|
1008 |
+
|
1009 |
+
dilation_type = np.random.choice(['square_dilation'])
|
1010 |
+
random_mask = random_mask_func(original_mask, dilation_type, dilation_size).squeeze()
|
1011 |
+
|
1012 |
+
mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
|
1013 |
+
|
1014 |
+
masked_image = original_image * (1 - (random_mask[:,:,None]>0))
|
1015 |
+
masked_image = masked_image.astype(original_image.dtype)
|
1016 |
+
masked_image = Image.fromarray(masked_image)
|
1017 |
+
|
1018 |
+
return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
|
1019 |
+
|
1020 |
+
|
1021 |
+
def process_erosion_mask(input_image,
|
1022 |
+
original_image,
|
1023 |
+
original_mask,
|
1024 |
+
resize_default,
|
1025 |
+
aspect_ratio_name,
|
1026 |
+
dilation_size=20):
|
1027 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
1028 |
+
input_mask = np.asarray(alpha_mask)
|
1029 |
+
output_w, output_h = aspect_ratios[aspect_ratio_name]
|
1030 |
+
if output_w == "" or output_h == "":
|
1031 |
+
output_h, output_w = original_image.shape[:2]
|
1032 |
+
if resize_default:
|
1033 |
+
short_side = min(output_w, output_h)
|
1034 |
+
scale_ratio = 640 / short_side
|
1035 |
+
output_w = int(output_w * scale_ratio)
|
1036 |
+
output_h = int(output_h * scale_ratio)
|
1037 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
1038 |
+
original_image = np.array(original_image)
|
1039 |
+
if input_mask is not None:
|
1040 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
1041 |
+
input_mask = np.array(input_mask)
|
1042 |
+
if original_mask is not None:
|
1043 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
1044 |
+
original_mask = np.array(original_mask)
|
1045 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1046 |
+
else:
|
1047 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1048 |
+
pass
|
1049 |
+
else:
|
1050 |
+
if resize_default:
|
1051 |
+
short_side = min(output_w, output_h)
|
1052 |
+
scale_ratio = 640 / short_side
|
1053 |
+
output_w = int(output_w * scale_ratio)
|
1054 |
+
output_h = int(output_h * scale_ratio)
|
1055 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1056 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
1057 |
+
original_image = np.array(original_image)
|
1058 |
+
if input_mask is not None:
|
1059 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
1060 |
+
input_mask = np.array(input_mask)
|
1061 |
+
if original_mask is not None:
|
1062 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
1063 |
+
original_mask = np.array(original_mask)
|
1064 |
+
|
1065 |
+
if input_mask.max() == 0:
|
1066 |
+
original_mask = original_mask
|
1067 |
+
else:
|
1068 |
+
original_mask = input_mask
|
1069 |
+
|
1070 |
+
if original_mask is None:
|
1071 |
+
raise gr.Error('Please generate mask first')
|
1072 |
+
|
1073 |
+
if original_mask.ndim == 2:
|
1074 |
+
original_mask = original_mask[:,:,None]
|
1075 |
+
|
1076 |
+
dilation_type = np.random.choice(['square_erosion'])
|
1077 |
+
random_mask = random_mask_func(original_mask, dilation_type, dilation_size).squeeze()
|
1078 |
+
|
1079 |
+
mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
|
1080 |
+
|
1081 |
+
masked_image = original_image * (1 - (random_mask[:,:,None]>0))
|
1082 |
+
masked_image = masked_image.astype(original_image.dtype)
|
1083 |
+
masked_image = Image.fromarray(masked_image)
|
1084 |
+
|
1085 |
+
|
1086 |
+
return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
|
1087 |
+
|
1088 |
+
|
1089 |
+
def move_mask_left(input_image,
|
1090 |
+
original_image,
|
1091 |
+
original_mask,
|
1092 |
+
moving_pixels,
|
1093 |
+
resize_default,
|
1094 |
+
aspect_ratio_name):
|
1095 |
+
|
1096 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
1097 |
+
input_mask = np.asarray(alpha_mask)
|
1098 |
+
|
1099 |
+
output_w, output_h = aspect_ratios[aspect_ratio_name]
|
1100 |
+
if output_w == "" or output_h == "":
|
1101 |
+
output_h, output_w = original_image.shape[:2]
|
1102 |
+
if resize_default:
|
1103 |
+
short_side = min(output_w, output_h)
|
1104 |
+
scale_ratio = 640 / short_side
|
1105 |
+
output_w = int(output_w * scale_ratio)
|
1106 |
+
output_h = int(output_h * scale_ratio)
|
1107 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
1108 |
+
original_image = np.array(original_image)
|
1109 |
+
if input_mask is not None:
|
1110 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
1111 |
+
input_mask = np.array(input_mask)
|
1112 |
+
if original_mask is not None:
|
1113 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
1114 |
+
original_mask = np.array(original_mask)
|
1115 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1116 |
+
else:
|
1117 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1118 |
+
pass
|
1119 |
+
else:
|
1120 |
+
if resize_default:
|
1121 |
+
short_side = min(output_w, output_h)
|
1122 |
+
scale_ratio = 640 / short_side
|
1123 |
+
output_w = int(output_w * scale_ratio)
|
1124 |
+
output_h = int(output_h * scale_ratio)
|
1125 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1126 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
1127 |
+
original_image = np.array(original_image)
|
1128 |
+
if input_mask is not None:
|
1129 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
1130 |
+
input_mask = np.array(input_mask)
|
1131 |
+
if original_mask is not None:
|
1132 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
1133 |
+
original_mask = np.array(original_mask)
|
1134 |
+
|
1135 |
+
if input_mask.max() == 0:
|
1136 |
+
original_mask = original_mask
|
1137 |
+
else:
|
1138 |
+
original_mask = input_mask
|
1139 |
+
|
1140 |
+
if original_mask is None:
|
1141 |
+
raise gr.Error('Please generate mask first')
|
1142 |
+
|
1143 |
+
if original_mask.ndim == 2:
|
1144 |
+
original_mask = original_mask[:,:,None]
|
1145 |
+
|
1146 |
+
moved_mask = move_mask_func(original_mask, 'left', int(moving_pixels)).squeeze()
|
1147 |
+
mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
|
1148 |
+
|
1149 |
+
masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
|
1150 |
+
masked_image = masked_image.astype(original_image.dtype)
|
1151 |
+
masked_image = Image.fromarray(masked_image)
|
1152 |
+
|
1153 |
+
if moved_mask.max() <= 1:
|
1154 |
+
moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
|
1155 |
+
original_mask = moved_mask
|
1156 |
+
return [masked_image], [mask_image], original_mask.astype(np.uint8)
|
1157 |
+
|
1158 |
+
|
1159 |
+
def move_mask_right(input_image,
|
1160 |
+
original_image,
|
1161 |
+
original_mask,
|
1162 |
+
moving_pixels,
|
1163 |
+
resize_default,
|
1164 |
+
aspect_ratio_name):
|
1165 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
1166 |
+
input_mask = np.asarray(alpha_mask)
|
1167 |
+
|
1168 |
+
output_w, output_h = aspect_ratios[aspect_ratio_name]
|
1169 |
+
if output_w == "" or output_h == "":
|
1170 |
+
output_h, output_w = original_image.shape[:2]
|
1171 |
+
if resize_default:
|
1172 |
+
short_side = min(output_w, output_h)
|
1173 |
+
scale_ratio = 640 / short_side
|
1174 |
+
output_w = int(output_w * scale_ratio)
|
1175 |
+
output_h = int(output_h * scale_ratio)
|
1176 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
1177 |
+
original_image = np.array(original_image)
|
1178 |
+
if input_mask is not None:
|
1179 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
1180 |
+
input_mask = np.array(input_mask)
|
1181 |
+
if original_mask is not None:
|
1182 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
1183 |
+
original_mask = np.array(original_mask)
|
1184 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1185 |
+
else:
|
1186 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1187 |
+
pass
|
1188 |
+
else:
|
1189 |
+
if resize_default:
|
1190 |
+
short_side = min(output_w, output_h)
|
1191 |
+
scale_ratio = 640 / short_side
|
1192 |
+
output_w = int(output_w * scale_ratio)
|
1193 |
+
output_h = int(output_h * scale_ratio)
|
1194 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1195 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
1196 |
+
original_image = np.array(original_image)
|
1197 |
+
if input_mask is not None:
|
1198 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
1199 |
+
input_mask = np.array(input_mask)
|
1200 |
+
if original_mask is not None:
|
1201 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
1202 |
+
original_mask = np.array(original_mask)
|
1203 |
+
|
1204 |
+
if input_mask.max() == 0:
|
1205 |
+
original_mask = original_mask
|
1206 |
+
else:
|
1207 |
+
original_mask = input_mask
|
1208 |
+
|
1209 |
+
if original_mask is None:
|
1210 |
+
raise gr.Error('Please generate mask first')
|
1211 |
+
|
1212 |
+
if original_mask.ndim == 2:
|
1213 |
+
original_mask = original_mask[:,:,None]
|
1214 |
+
|
1215 |
+
moved_mask = move_mask_func(original_mask, 'right', int(moving_pixels)).squeeze()
|
1216 |
+
|
1217 |
+
mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
|
1218 |
+
|
1219 |
+
masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
|
1220 |
+
masked_image = masked_image.astype(original_image.dtype)
|
1221 |
+
masked_image = Image.fromarray(masked_image)
|
1222 |
+
|
1223 |
+
|
1224 |
+
if moved_mask.max() <= 1:
|
1225 |
+
moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
|
1226 |
+
original_mask = moved_mask
|
1227 |
+
|
1228 |
+
return [masked_image], [mask_image], original_mask.astype(np.uint8)
|
1229 |
+
|
1230 |
+
|
1231 |
+
def move_mask_up(input_image,
|
1232 |
+
original_image,
|
1233 |
+
original_mask,
|
1234 |
+
moving_pixels,
|
1235 |
+
resize_default,
|
1236 |
+
aspect_ratio_name):
|
1237 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
1238 |
+
input_mask = np.asarray(alpha_mask)
|
1239 |
+
|
1240 |
+
output_w, output_h = aspect_ratios[aspect_ratio_name]
|
1241 |
+
if output_w == "" or output_h == "":
|
1242 |
+
output_h, output_w = original_image.shape[:2]
|
1243 |
+
if resize_default:
|
1244 |
+
short_side = min(output_w, output_h)
|
1245 |
+
scale_ratio = 640 / short_side
|
1246 |
+
output_w = int(output_w * scale_ratio)
|
1247 |
+
output_h = int(output_h * scale_ratio)
|
1248 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
1249 |
+
original_image = np.array(original_image)
|
1250 |
+
if input_mask is not None:
|
1251 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
1252 |
+
input_mask = np.array(input_mask)
|
1253 |
+
if original_mask is not None:
|
1254 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
1255 |
+
original_mask = np.array(original_mask)
|
1256 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1257 |
+
else:
|
1258 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1259 |
+
pass
|
1260 |
+
else:
|
1261 |
+
if resize_default:
|
1262 |
+
short_side = min(output_w, output_h)
|
1263 |
+
scale_ratio = 640 / short_side
|
1264 |
+
output_w = int(output_w * scale_ratio)
|
1265 |
+
output_h = int(output_h * scale_ratio)
|
1266 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1267 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
1268 |
+
original_image = np.array(original_image)
|
1269 |
+
if input_mask is not None:
|
1270 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
1271 |
+
input_mask = np.array(input_mask)
|
1272 |
+
if original_mask is not None:
|
1273 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
1274 |
+
original_mask = np.array(original_mask)
|
1275 |
+
|
1276 |
+
if input_mask.max() == 0:
|
1277 |
+
original_mask = original_mask
|
1278 |
+
else:
|
1279 |
+
original_mask = input_mask
|
1280 |
+
|
1281 |
+
if original_mask is None:
|
1282 |
+
raise gr.Error('Please generate mask first')
|
1283 |
+
|
1284 |
+
if original_mask.ndim == 2:
|
1285 |
+
original_mask = original_mask[:,:,None]
|
1286 |
+
|
1287 |
+
moved_mask = move_mask_func(original_mask, 'up', int(moving_pixels)).squeeze()
|
1288 |
+
mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
|
1289 |
+
|
1290 |
+
masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
|
1291 |
+
masked_image = masked_image.astype(original_image.dtype)
|
1292 |
+
masked_image = Image.fromarray(masked_image)
|
1293 |
+
|
1294 |
+
if moved_mask.max() <= 1:
|
1295 |
+
moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
|
1296 |
+
original_mask = moved_mask
|
1297 |
+
|
1298 |
+
return [masked_image], [mask_image], original_mask.astype(np.uint8)
|
1299 |
+
|
1300 |
+
|
1301 |
+
def move_mask_down(input_image,
|
1302 |
+
original_image,
|
1303 |
+
original_mask,
|
1304 |
+
moving_pixels,
|
1305 |
+
resize_default,
|
1306 |
+
aspect_ratio_name):
|
1307 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
1308 |
+
input_mask = np.asarray(alpha_mask)
|
1309 |
+
output_w, output_h = aspect_ratios[aspect_ratio_name]
|
1310 |
+
if output_w == "" or output_h == "":
|
1311 |
+
output_h, output_w = original_image.shape[:2]
|
1312 |
+
if resize_default:
|
1313 |
+
short_side = min(output_w, output_h)
|
1314 |
+
scale_ratio = 640 / short_side
|
1315 |
+
output_w = int(output_w * scale_ratio)
|
1316 |
+
output_h = int(output_h * scale_ratio)
|
1317 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
1318 |
+
original_image = np.array(original_image)
|
1319 |
+
if input_mask is not None:
|
1320 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
1321 |
+
input_mask = np.array(input_mask)
|
1322 |
+
if original_mask is not None:
|
1323 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
1324 |
+
original_mask = np.array(original_mask)
|
1325 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1326 |
+
else:
|
1327 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1328 |
+
pass
|
1329 |
+
else:
|
1330 |
+
if resize_default:
|
1331 |
+
short_side = min(output_w, output_h)
|
1332 |
+
scale_ratio = 640 / short_side
|
1333 |
+
output_w = int(output_w * scale_ratio)
|
1334 |
+
output_h = int(output_h * scale_ratio)
|
1335 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1336 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
1337 |
+
original_image = np.array(original_image)
|
1338 |
+
if input_mask is not None:
|
1339 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
1340 |
+
input_mask = np.array(input_mask)
|
1341 |
+
if original_mask is not None:
|
1342 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
1343 |
+
original_mask = np.array(original_mask)
|
1344 |
+
|
1345 |
+
if input_mask.max() == 0:
|
1346 |
+
original_mask = original_mask
|
1347 |
+
else:
|
1348 |
+
original_mask = input_mask
|
1349 |
+
|
1350 |
+
if original_mask is None:
|
1351 |
+
raise gr.Error('Please generate mask first')
|
1352 |
+
|
1353 |
+
if original_mask.ndim == 2:
|
1354 |
+
original_mask = original_mask[:,:,None]
|
1355 |
+
|
1356 |
+
moved_mask = move_mask_func(original_mask, 'down', int(moving_pixels)).squeeze()
|
1357 |
+
mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
|
1358 |
+
|
1359 |
+
masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
|
1360 |
+
masked_image = masked_image.astype(original_image.dtype)
|
1361 |
+
masked_image = Image.fromarray(masked_image)
|
1362 |
+
|
1363 |
+
if moved_mask.max() <= 1:
|
1364 |
+
moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
|
1365 |
+
original_mask = moved_mask
|
1366 |
+
|
1367 |
+
return [masked_image], [mask_image], original_mask.astype(np.uint8)
|
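The four move handlers (up/down/left/right) end with the same post-processing: rebuild the black-and-white mask preview and black out the masked region of the original image. A minimal sketch of that compositing step, as a hypothetical standalone example:

import numpy as np
from PIL import Image

img = np.full((8, 8, 3), 200, dtype=np.uint8)                      # stand-in for original_image
moved = np.zeros((8, 8), dtype=np.uint8)
moved[2:5, 2:5] = 255                                              # stand-in for the shifted mask

mask_preview = Image.fromarray((moved > 0).astype(np.uint8) * 255).convert("RGB")
masked_preview = Image.fromarray((img * (1 - (moved[:, :, None] > 0))).astype(img.dtype))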
1368 |
+
|
1369 |
+
|
1370 |
+
def invert_mask(input_image,
|
1371 |
+
original_image,
|
1372 |
+
original_mask,
|
1373 |
+
):
|
1374 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
1375 |
+
input_mask = np.asarray(alpha_mask)
|
1376 |
+
if input_mask.max() == 0:
|
1377 |
+
original_mask = 1 - (original_mask>0).astype(np.uint8)
|
1378 |
+
else:
|
1379 |
+
original_mask = 1 - (input_mask>0).astype(np.uint8)
|
1380 |
+
|
1381 |
+
if original_mask is None:
|
1382 |
+
raise gr.Error('Please generate mask first')
|
1383 |
+
|
1384 |
+
original_mask = original_mask.squeeze()
|
1385 |
+
mask_image = Image.fromarray(original_mask*255).convert("RGB")
|
1386 |
+
|
1387 |
+
if original_mask.ndim == 2:
|
1388 |
+
original_mask = original_mask[:,:,None]
|
1389 |
+
|
1390 |
+
if original_mask.max() <= 1:
|
1391 |
+
original_mask = (original_mask * 255).astype(np.uint8)
|
1392 |
+
|
1393 |
+
masked_image = original_image * (1 - (original_mask>0))
|
1394 |
+
masked_image = masked_image.astype(original_image.dtype)
|
1395 |
+
masked_image = Image.fromarray(masked_image)
|
1396 |
+
|
1397 |
+
return [masked_image], [mask_image], original_mask, True
|
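The inversion above is a single NumPy expression over the stored mask; a small sketch of it in isolation (hypothetical example, using the same 0/255 mask convention):

import numpy as np

mask = np.zeros((4, 4), dtype=np.uint8)
mask[1:3, 1:3] = 255

inverted = 1 - (mask > 0).astype(np.uint8)        # 1 where the mask was empty, 0 where it was drawn
inverted_255 = (inverted * 255).astype(np.uint8)  # back to 0/255 before display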
1398 |
+
|
1399 |
+
|
1400 |
+
def init_img(base,
|
1401 |
+
init_type,
|
1402 |
+
prompt,
|
1403 |
+
aspect_ratio,
|
1404 |
+
example_change_times
|
1405 |
+
):
|
1406 |
+
image_pil = base["background"].convert("RGB")
|
1407 |
+
original_image = np.array(image_pil)
|
1408 |
+
if max(original_image.shape[0], original_image.shape[1]) * 1.0 / min(original_image.shape[0], original_image.shape[1])>2.0:
|
1409 |
+
raise gr.Error('image aspect ratio cannot be larger than 2.0')
|
1410 |
+
if init_type in MASK_IMAGE_PATH.keys() and example_change_times < 2:
|
1411 |
+
mask_gallery = [Image.open(MASK_IMAGE_PATH[init_type]).convert("L")]
|
1412 |
+
masked_gallery = [Image.open(MASKED_IMAGE_PATH[init_type]).convert("RGB")]
|
1413 |
+
result_gallery = [Image.open(OUTPUT_IMAGE_PATH[init_type]).convert("RGB")]
|
1414 |
+
width, height = image_pil.size
|
1415 |
+
image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
|
1416 |
+
height_new, width_new = image_processor.get_default_height_width(image_pil, height, width)
|
1417 |
+
image_pil = image_pil.resize((width_new, height_new))
|
1418 |
+
mask_gallery[0] = mask_gallery[0].resize((width_new, height_new))
|
1419 |
+
masked_gallery[0] = masked_gallery[0].resize((width_new, height_new))
|
1420 |
+
result_gallery[0] = result_gallery[0].resize((width_new, height_new))
|
1421 |
+
original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1
|
1422 |
+
return base, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, "", "", "Custom resolution", False, False, example_change_times
|
1423 |
+
else:
|
1424 |
+
if aspect_ratio not in ASPECT_RATIO_LABELS:
|
1425 |
+
aspect_ratio = "Custom resolution"
|
1426 |
+
return base, original_image, None, "", None, None, None, "", "", aspect_ratio, True, False, 0
|
1427 |
+
|
1428 |
+
|
1429 |
+
def reset_func(input_image,
|
1430 |
+
original_image,
|
1431 |
+
original_mask,
|
1432 |
+
prompt,
|
1433 |
+
target_prompt,
|
1434 |
+
):
|
1435 |
+
input_image = None
|
1436 |
+
original_image = None
|
1437 |
+
original_mask = None
|
1438 |
+
prompt = ''
|
1439 |
+
mask_gallery = []
|
1440 |
+
masked_gallery = []
|
1441 |
+
result_gallery = []
|
1442 |
+
target_prompt = ''
|
1443 |
+
if torch.cuda.is_available():
|
1444 |
+
torch.cuda.empty_cache()
|
1445 |
+
return input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, True, False
|
1446 |
+
|
1447 |
+
|
1448 |
+
def update_example(example_type,
|
1449 |
+
prompt,
|
1450 |
+
example_change_times):
|
1451 |
+
input_image = INPUT_IMAGE_PATH[example_type]
|
1452 |
+
image_pil = Image.open(input_image).convert("RGB")
|
1453 |
+
mask_gallery = [Image.open(MASK_IMAGE_PATH[example_type]).convert("L")]
|
1454 |
+
masked_gallery = [Image.open(MASKED_IMAGE_PATH[example_type]).convert("RGB")]
|
1455 |
+
result_gallery = [Image.open(OUTPUT_IMAGE_PATH[example_type]).convert("RGB")]
|
1456 |
+
width, height = image_pil.size
|
1457 |
+
image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
|
1458 |
+
height_new, width_new = image_processor.get_default_height_width(image_pil, height, width)
|
1459 |
+
image_pil = image_pil.resize((width_new, height_new))
|
1460 |
+
mask_gallery[0] = mask_gallery[0].resize((width_new, height_new))
|
1461 |
+
masked_gallery[0] = masked_gallery[0].resize((width_new, height_new))
|
1462 |
+
result_gallery[0] = result_gallery[0].resize((width_new, height_new))
|
1463 |
+
|
1464 |
+
original_image = np.array(image_pil)
|
1465 |
+
original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1
|
1466 |
+
aspect_ratio = "Custom resolution"
|
1467 |
+
example_change_times += 1
|
1468 |
+
return input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, "", False, example_change_times
|
1469 |
+
|
1470 |
+
|
1471 |
+
block = gr.Blocks(
|
1472 |
+
theme=gr.themes.Soft(
|
1473 |
+
radius_size=gr.themes.sizes.radius_none,
|
1474 |
+
text_size=gr.themes.sizes.text_md
|
1475 |
+
)
|
1476 |
+
)
|
1477 |
+
with block as demo:
|
1478 |
+
with gr.Row():
|
1479 |
+
with gr.Column():
|
1480 |
+
gr.HTML(head)
|
1481 |
+
|
1482 |
+
gr.Markdown(descriptions)
|
1483 |
+
|
1484 |
+
with gr.Accordion(label="🧭 Instructions:", open=True, elem_id="accordion"):
|
1485 |
+
with gr.Row(equal_height=True):
|
1486 |
+
gr.Markdown(instructions)
|
1487 |
+
|
1488 |
+
original_image = gr.State(value=None)
|
1489 |
+
original_mask = gr.State(value=None)
|
1490 |
+
category = gr.State(value=None)
|
1491 |
+
status = gr.State(value=None)
|
1492 |
+
invert_mask_state = gr.State(value=False)
|
1493 |
+
example_change_times = gr.State(value=0)
|
1494 |
+
|
1495 |
+
|
1496 |
+
with gr.Row():
|
1497 |
+
with gr.Column():
|
1498 |
+
with gr.Row():
|
1499 |
+
input_image = gr.ImageEditor(
|
1500 |
+
label="Input Image",
|
1501 |
+
type="pil",
|
1502 |
+
brush=gr.Brush(colors=["#FFFFFF"], default_size = 30, color_mode="fixed"),
|
1503 |
+
layers = False,
|
1504 |
+
interactive=True,
|
1505 |
+
height=1024,
|
1506 |
+
sources=["upload"],
|
1507 |
+
placeholder="Please click here or the icon below to upload the image.",
|
1508 |
+
)
|
1509 |
+
|
1510 |
+
prompt = gr.Textbox(label="⌨️ Instruction", placeholder="Please input your instruction.", value="",lines=1)
|
1511 |
+
run_button = gr.Button("💫 Run")
|
1512 |
+
|
1513 |
+
vlm_model_dropdown = gr.Dropdown(label="VLM model", choices=VLM_MODEL_NAMES, value=DEFAULT_VLM_MODEL_NAME, interactive=True)
|
1514 |
+
with gr.Group():
|
1515 |
+
with gr.Row():
|
1516 |
+
GPT4o_KEY = gr.Textbox(label="GPT4o API Key", placeholder="Please input your GPT4o API Key when use GPT4o VLM (highly recommended).", value="", lines=1)
|
1517 |
+
|
1518 |
+
GPT4o_KEY_submit = gr.Button("Submit and Verify")
|
1519 |
+
|
1520 |
+
|
1521 |
+
aspect_ratio = gr.Dropdown(label="Output aspect ratio", choices=ASPECT_RATIO_LABELS, value=DEFAULT_ASPECT_RATIO)
|
1522 |
+
resize_default = gr.Checkbox(label="Short edge resize to 640px", value=True)
|
1523 |
+
|
1524 |
+
with gr.Row():
|
1525 |
+
mask_button = gr.Button("Generate Mask")
|
1526 |
+
random_mask_button = gr.Button("Square/Circle Mask ")
|
1527 |
+
|
1528 |
+
|
1529 |
+
with gr.Row():
|
1530 |
+
generate_target_prompt_button = gr.Button("Generate Target Prompt")
|
1531 |
+
|
1532 |
+
target_prompt = gr.Text(
|
1533 |
+
label="Input Target Prompt",
|
1534 |
+
max_lines=5,
|
1535 |
+
placeholder="VLM-generated target prompt, you can first generate if and then modify it (optional)",
|
1536 |
+
value='',
|
1537 |
+
lines=2
|
1538 |
+
)
|
1539 |
+
|
1540 |
+
with gr.Accordion("Advanced Options", open=False, elem_id="accordion1"):
|
1541 |
+
base_model_dropdown = gr.Dropdown(label="Base model", choices=BASE_MODELS, value=DEFAULT_BASE_MODEL, interactive=True)
|
1542 |
+
negative_prompt = gr.Text(
|
1543 |
+
label="Negative Prompt",
|
1544 |
+
max_lines=5,
|
1545 |
+
placeholder="Please input your negative prompt",
|
1546 |
+
value='ugly, low quality',lines=1
|
1547 |
+
)
|
1548 |
+
|
1549 |
+
control_strength = gr.Slider(
|
1550 |
+
label="Control Strength: ", show_label=True, minimum=0, maximum=1.1, value=1, step=0.01
|
1551 |
+
)
|
1552 |
+
with gr.Group():
|
1553 |
+
seed = gr.Slider(
|
1554 |
+
label="Seed: ", minimum=0, maximum=2147483647, step=1, value=648464818
|
1555 |
+
)
|
1556 |
+
randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
|
1557 |
+
|
1558 |
+
blending = gr.Checkbox(label="Blending mode", value=True)
|
1559 |
+
|
1560 |
+
|
1561 |
+
num_samples = gr.Slider(
|
1562 |
+
label="Num samples", minimum=0, maximum=4, step=1, value=4
|
1563 |
+
)
|
1564 |
+
|
1565 |
+
with gr.Group():
|
1566 |
+
with gr.Row():
|
1567 |
+
guidance_scale = gr.Slider(
|
1568 |
+
label="Guidance scale",
|
1569 |
+
minimum=1,
|
1570 |
+
maximum=12,
|
1571 |
+
step=0.1,
|
1572 |
+
value=7.5,
|
1573 |
+
)
|
1574 |
+
num_inference_steps = gr.Slider(
|
1575 |
+
label="Number of inference steps",
|
1576 |
+
minimum=1,
|
1577 |
+
maximum=50,
|
1578 |
+
step=1,
|
1579 |
+
value=50,
|
1580 |
+
)
|
1581 |
+
|
1582 |
+
|
1583 |
+
with gr.Column():
|
1584 |
+
with gr.Row():
|
1585 |
+
with gr.Tab(elem_classes="feedback", label="Masked Image"):
|
1586 |
+
masked_gallery = gr.Gallery(label='Masked Image', show_label=True, elem_id="gallery", preview=True, height=360)
|
1587 |
+
with gr.Tab(elem_classes="feedback", label="Mask"):
|
1588 |
+
mask_gallery = gr.Gallery(label='Mask', show_label=True, elem_id="gallery", preview=True, height=360)
|
1589 |
+
|
1590 |
+
invert_mask_button = gr.Button("Invert Mask")
|
1591 |
+
dilation_size = gr.Slider(
|
1592 |
+
label="Dilation size: ", minimum=0, maximum=50, step=1, value=20
|
1593 |
+
)
|
1594 |
+
with gr.Row():
|
1595 |
+
dilation_mask_button = gr.Button("Dilation Generated Mask")
|
1596 |
+
erosion_mask_button = gr.Button("Erosion Generated Mask")
|
1597 |
+
|
1598 |
+
moving_pixels = gr.Slider(
|
1599 |
+
label="Moving pixels:", show_label=True, minimum=0, maximum=50, value=4, step=1
|
1600 |
+
)
|
1601 |
+
with gr.Row():
|
1602 |
+
move_left_button = gr.Button("Move Left")
|
1603 |
+
move_right_button = gr.Button("Move Right")
|
1604 |
+
with gr.Row():
|
1605 |
+
move_up_button = gr.Button("Move Up")
|
1606 |
+
move_down_button = gr.Button("Move Down")
|
1607 |
+
|
1608 |
+
with gr.Tab(elem_classes="feedback", label="Output"):
|
1609 |
+
result_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", preview=True, height=400)
|
1610 |
+
|
1611 |
+
# target_prompt_output = gr.Text(label="Output Target Prompt", value="", lines=1, interactive=False)
|
1612 |
+
|
1613 |
+
reset_button = gr.Button("Reset")
|
1614 |
+
|
1615 |
+
init_type = gr.Textbox(label="Init Name", value="", visible=False)
|
1616 |
+
example_type = gr.Textbox(label="Example Name", value="", visible=False)
|
1617 |
+
|
1618 |
+
|
1619 |
+
|
1620 |
+
with gr.Row():
|
1621 |
+
example = gr.Examples(
|
1622 |
+
label="Quick Example",
|
1623 |
+
examples=EXAMPLES,
|
1624 |
+
inputs=[input_image, prompt, seed, init_type, example_type, blending, resize_default, vlm_model_dropdown],
|
1625 |
+
examples_per_page=10,
|
1626 |
+
cache_examples=False,
|
1627 |
+
)
|
1628 |
+
|
1629 |
+
|
1630 |
+
with gr.Accordion(label="🎬 Feature Details:", open=True, elem_id="accordion"):
|
1631 |
+
with gr.Row(equal_height=True):
|
1632 |
+
gr.Markdown(tips)
|
1633 |
+
|
1634 |
+
with gr.Row():
|
1635 |
+
gr.Markdown(citation)
|
1636 |
+
|
1637 |
+
## gr.Examples cannot update a gr.Gallery directly, so the following two functions update the galleries instead.
|
1638 |
+
## They also resolve the conflict between the image-upload and example-change callbacks.
|
1639 |
+
input_image.upload(
|
1640 |
+
init_img,
|
1641 |
+
[input_image, init_type, prompt, aspect_ratio, example_change_times],
|
1642 |
+
[input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, init_type, aspect_ratio, resize_default, invert_mask_state, example_change_times]
|
1643 |
+
)
|
1644 |
+
example_type.change(fn=update_example, inputs=[example_type, prompt, example_change_times], outputs=[input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, target_prompt, invert_mask_state, example_change_times])
|
1645 |
+
|
1646 |
+
## vlm and base model dropdown
|
1647 |
+
vlm_model_dropdown.change(fn=update_vlm_model, inputs=[vlm_model_dropdown], outputs=[status])
|
1648 |
+
base_model_dropdown.change(fn=update_base_model, inputs=[base_model_dropdown], outputs=[status])
|
1649 |
+
|
1650 |
+
|
1651 |
+
GPT4o_KEY_submit.click(fn=submit_GPT4o_KEY, inputs=[GPT4o_KEY], outputs=[GPT4o_KEY, vlm_model_dropdown])
|
1652 |
+
invert_mask_button.click(fn=invert_mask, inputs=[input_image, original_image, original_mask], outputs=[masked_gallery, mask_gallery, original_mask, invert_mask_state])
|
1653 |
+
|
1654 |
+
|
1655 |
+
ips=[input_image,
|
1656 |
+
original_image,
|
1657 |
+
original_mask,
|
1658 |
+
prompt,
|
1659 |
+
negative_prompt,
|
1660 |
+
control_strength,
|
1661 |
+
seed,
|
1662 |
+
randomize_seed,
|
1663 |
+
guidance_scale,
|
1664 |
+
num_inference_steps,
|
1665 |
+
num_samples,
|
1666 |
+
blending,
|
1667 |
+
category,
|
1668 |
+
target_prompt,
|
1669 |
+
resize_default,
|
1670 |
+
aspect_ratio,
|
1671 |
+
invert_mask_state]
|
1672 |
+
|
1673 |
+
## run brushedit
|
1674 |
+
run_button.click(fn=process, inputs=ips, outputs=[result_gallery, mask_gallery, masked_gallery, prompt, target_prompt, invert_mask_state])
|
1675 |
+
|
1676 |
+
## mask func
|
1677 |
+
mask_button.click(fn=process_mask, inputs=[input_image, original_image, prompt, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask, category])
|
1678 |
+
random_mask_button.click(fn=process_random_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
|
1679 |
+
dilation_mask_button.click(fn=process_dilation_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[ masked_gallery, mask_gallery, original_mask])
|
1680 |
+
erosion_mask_button.click(fn=process_erosion_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[ masked_gallery, mask_gallery, original_mask])
|
1681 |
+
|
1682 |
+
## move mask func
|
1683 |
+
move_left_button.click(fn=move_mask_left, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
|
1684 |
+
move_right_button.click(fn=move_mask_right, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
|
1685 |
+
move_up_button.click(fn=move_mask_up, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
|
1686 |
+
move_down_button.click(fn=move_mask_down, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
|
1687 |
+
|
1688 |
+
## prompt func
|
1689 |
+
generate_target_prompt_button.click(fn=generate_target_prompt, inputs=[input_image, original_image, prompt], outputs=[target_prompt])
|
1690 |
+
|
1691 |
+
## reset func
|
1692 |
+
reset_button.click(fn=reset_func, inputs=[input_image, original_image, original_mask, prompt, target_prompt], outputs=[input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, resize_default, invert_mask_state])
|
1693 |
+
|
1694 |
+
## if you hit a localhost access error, try the following launch arguments
|
1695 |
+
demo.launch(server_name="0.0.0.0", server_port=12345, share=True)
|
1696 |
+
# demo.launch()
|
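The wiring above follows the standard Gradio Blocks pattern: every .click(...) / .change(...) call binds a Python callback to component values. A stripped-down sketch of the same pattern, illustrative only and not part of the app:

import gradio as gr

def echo(text):
    return f"You typed: {text}"

with gr.Blocks() as mini_demo:
    box = gr.Textbox(label="Input")
    out = gr.Textbox(label="Output")
    gr.Button("Run").click(fn=echo, inputs=[box], outputs=[out])

# mini_demo.launch()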
brushedit_app_315_1.py
ADDED
@@ -0,0 +1,1624 @@
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import os, random, sys
|
4 |
+
import numpy as np
|
5 |
+
import requests
|
6 |
+
import torch
|
7 |
+
|
8 |
+
|
9 |
+
import gradio as gr
|
10 |
+
|
11 |
+
from PIL import Image
|
12 |
+
|
13 |
+
|
14 |
+
from huggingface_hub import hf_hub_download, snapshot_download
|
15 |
+
from scipy.ndimage import binary_dilation, binary_erosion
|
16 |
+
from transformers import (LlavaNextProcessor, LlavaNextForConditionalGeneration,
|
17 |
+
Qwen2VLForConditionalGeneration, Qwen2VLProcessor)
|
18 |
+
|
19 |
+
from segment_anything import SamPredictor, build_sam, SamAutomaticMaskGenerator
|
20 |
+
from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler
|
21 |
+
from diffusers.image_processor import VaeImageProcessor
|
22 |
+
|
23 |
+
|
24 |
+
from app.src.vlm_pipeline import (
|
25 |
+
vlm_response_editing_type,
|
26 |
+
vlm_response_object_wait_for_edit,
|
27 |
+
vlm_response_mask,
|
28 |
+
vlm_response_prompt_after_apply_instruction
|
29 |
+
)
|
30 |
+
from app.src.brushedit_all_in_one_pipeline import BrushEdit_Pipeline
|
31 |
+
from app.utils.utils import load_grounding_dino_model
|
32 |
+
|
33 |
+
from app.src.vlm_template import vlms_template
|
34 |
+
from app.src.base_model_template import base_models_template
|
35 |
+
from app.src.aspect_ratio_template import aspect_ratios
|
36 |
+
|
37 |
+
from openai import OpenAI
|
38 |
+
# base_openai_url = ""
|
39 |
+
|
40 |
+
#### Description ####
|
41 |
+
logo = r"""
|
42 |
+
<center><img src='./assets/logo_brushedit.png' alt='BrushEdit logo' style="width:80px; margin-bottom:10px"></center>
|
43 |
+
"""
|
44 |
+
head = r"""
|
45 |
+
<div style="text-align: center;">
|
46 |
+
<h1> BrushEdit: All-In-One Image Inpainting and Editing</h1>
|
47 |
+
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
|
48 |
+
<a href='https://liyaowei-stu.github.io/project/BrushEdit/'><img src='https://img.shields.io/badge/Project_Page-BrushEdit-green' alt='Project Page'></a>
|
49 |
+
<a href='https://arxiv.org/abs/2412.10316'><img src='https://img.shields.io/badge/Paper-Arxiv-blue'></a>
|
50 |
+
<a href='https://github.com/TencentARC/BrushEdit'><img src='https://img.shields.io/badge/Code-Github-orange'></a>
|
51 |
+
|
52 |
+
</div>
|
53 |
+
</br>
|
54 |
+
</div>
|
55 |
+
"""
|
56 |
+
descriptions = r"""
|
57 |
+
Official Gradio Demo
|
58 |
+
"""
|
59 |
+
|
60 |
+
instructions = r"""
|
61 |
+
To be added.
|
62 |
+
"""
|
63 |
+
|
64 |
+
tips = r"""
|
65 |
+
To be added.
|
66 |
+
|
67 |
+
"""
|
68 |
+
|
69 |
+
|
70 |
+
|
71 |
+
citation = r"""
|
72 |
+
To be added.
|
73 |
+
"""
|
74 |
+
|
75 |
+
# - - - - - examples - - - - - #
|
76 |
+
EXAMPLES = [
|
77 |
+
|
78 |
+
[
|
79 |
+
Image.open("./assets/frog/frog.jpeg").convert("RGBA"),
|
80 |
+
"add a magic hat on frog head.",
|
81 |
+
642087011,
|
82 |
+
"frog",
|
83 |
+
"frog",
|
84 |
+
True,
|
85 |
+
False,
|
86 |
+
"GPT4-o (Highly Recommended)"
|
87 |
+
],
|
88 |
+
[
|
89 |
+
Image.open("./assets/chinese_girl/chinese_girl.png").convert("RGBA"),
|
90 |
+
"replace the background to ancient China.",
|
91 |
+
648464818,
|
92 |
+
"chinese_girl",
|
93 |
+
"chinese_girl",
|
94 |
+
True,
|
95 |
+
False,
|
96 |
+
"GPT4-o (Highly Recommended)"
|
97 |
+
],
|
98 |
+
[
|
99 |
+
Image.open("./assets/angel_christmas/angel_christmas.png").convert("RGBA"),
|
100 |
+
"remove the deer.",
|
101 |
+
648464818,
|
102 |
+
"angel_christmas",
|
103 |
+
"angel_christmas",
|
104 |
+
False,
|
105 |
+
False,
|
106 |
+
"GPT4-o (Highly Recommended)"
|
107 |
+
],
|
108 |
+
[
|
109 |
+
Image.open("./assets/sunflower_girl/sunflower_girl.png").convert("RGBA"),
|
110 |
+
"add a wreath on head.",
|
111 |
+
648464818,
|
112 |
+
"sunflower_girl",
|
113 |
+
"sunflower_girl",
|
114 |
+
True,
|
115 |
+
False,
|
116 |
+
"GPT4-o (Highly Recommended)"
|
117 |
+
],
|
118 |
+
[
|
119 |
+
Image.open("./assets/girl_on_sun/girl_on_sun.png").convert("RGBA"),
|
120 |
+
"add a butterfly fairy.",
|
121 |
+
648464818,
|
122 |
+
"girl_on_sun",
|
123 |
+
"girl_on_sun",
|
124 |
+
True,
|
125 |
+
False,
|
126 |
+
"GPT4-o (Highly Recommended)"
|
127 |
+
],
|
128 |
+
[
|
129 |
+
Image.open("./assets/spider_man_rm/spider_man.png").convert("RGBA"),
|
130 |
+
"remove the christmas hat.",
|
131 |
+
642087011,
|
132 |
+
"spider_man_rm",
|
133 |
+
"spider_man_rm",
|
134 |
+
False,
|
135 |
+
False,
|
136 |
+
"GPT4-o (Highly Recommended)"
|
137 |
+
],
|
138 |
+
[
|
139 |
+
Image.open("./assets/anime_flower/anime_flower.png").convert("RGBA"),
|
140 |
+
"remove the flower.",
|
141 |
+
642087011,
|
142 |
+
"anime_flower",
|
143 |
+
"anime_flower",
|
144 |
+
False,
|
145 |
+
False,
|
146 |
+
"GPT4-o (Highly Recommended)"
|
147 |
+
],
|
148 |
+
[
|
149 |
+
Image.open("./assets/chenduling/chengduling.jpg").convert("RGBA"),
|
150 |
+
"replace the clothes to a delicated floral skirt.",
|
151 |
+
648464818,
|
152 |
+
"chenduling",
|
153 |
+
"chenduling",
|
154 |
+
True,
|
155 |
+
False,
|
156 |
+
"GPT4-o (Highly Recommended)"
|
157 |
+
],
|
158 |
+
[
|
159 |
+
Image.open("./assets/hedgehog_rp_bg/hedgehog.png").convert("RGBA"),
|
160 |
+
"make the hedgehog in Italy.",
|
161 |
+
648464818,
|
162 |
+
"hedgehog_rp_bg",
|
163 |
+
"hedgehog_rp_bg",
|
164 |
+
True,
|
165 |
+
False,
|
166 |
+
"GPT4-o (Highly Recommended)"
|
167 |
+
],
|
168 |
+
|
169 |
+
]
|
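# Each entry above appears to follow the order of the example inputs used elsewhere in this app:
# [input image, instruction, seed, init_type, example_type, blending flag, resize_default flag, VLM model name]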
170 |
+
|
171 |
+
INPUT_IMAGE_PATH = {
|
172 |
+
"frog": "./assets/frog/frog.jpeg",
|
173 |
+
"chinese_girl": "./assets/chinese_girl/chinese_girl.png",
|
174 |
+
"angel_christmas": "./assets/angel_christmas/angel_christmas.png",
|
175 |
+
"sunflower_girl": "./assets/sunflower_girl/sunflower_girl.png",
|
176 |
+
"girl_on_sun": "./assets/girl_on_sun/girl_on_sun.png",
|
177 |
+
"spider_man_rm": "./assets/spider_man_rm/spider_man.png",
|
178 |
+
"anime_flower": "./assets/anime_flower/anime_flower.png",
|
179 |
+
"chenduling": "./assets/chenduling/chengduling.jpg",
|
180 |
+
"hedgehog_rp_bg": "./assets/hedgehog_rp_bg/hedgehog.png",
|
181 |
+
}
|
182 |
+
MASK_IMAGE_PATH = {
|
183 |
+
"frog": "./assets/frog/mask_f7b350de-6f2c-49e3-b535-995c486d78e7.png",
|
184 |
+
"chinese_girl": "./assets/chinese_girl/mask_54759648-0989-48e0-bc82-f20e28b5ec29.png",
|
185 |
+
"angel_christmas": "./assets/angel_christmas/mask_f15d9b45-c978-4e3d-9f5f-251e308560c3.png",
|
186 |
+
"sunflower_girl": "./assets/sunflower_girl/mask_99cc50b4-7dc4-4de5-8748-ec10772f0317.png",
|
187 |
+
"girl_on_sun": "./assets/girl_on_sun/mask_264eac8b-8b65-479c-9755-020a60880c37.png",
|
188 |
+
"spider_man_rm": "./assets/spider_man_rm/mask_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png",
|
189 |
+
"anime_flower": "./assets/anime_flower/mask_37553172-9b38-4727-bf2e-37d7e2b93461.png",
|
190 |
+
"chenduling": "./assets/chenduling/mask_68e3ff6f-da07-4b37-91df-13d6eed7b997.png",
|
191 |
+
"hedgehog_rp_bg": "./assets/hedgehog_rp_bg/mask_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png",
|
192 |
+
}
|
193 |
+
MASKED_IMAGE_PATH = {
|
194 |
+
"frog": "./assets/frog/masked_image_f7b350de-6f2c-49e3-b535-995c486d78e7.png",
|
195 |
+
"chinese_girl": "./assets/chinese_girl/masked_image_54759648-0989-48e0-bc82-f20e28b5ec29.png",
|
196 |
+
"angel_christmas": "./assets/angel_christmas/masked_image_f15d9b45-c978-4e3d-9f5f-251e308560c3.png",
|
197 |
+
"sunflower_girl": "./assets/sunflower_girl/masked_image_99cc50b4-7dc4-4de5-8748-ec10772f0317.png",
|
198 |
+
"girl_on_sun": "./assets/girl_on_sun/masked_image_264eac8b-8b65-479c-9755-020a60880c37.png",
|
199 |
+
"spider_man_rm": "./assets/spider_man_rm/masked_image_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png",
|
200 |
+
"anime_flower": "./assets/anime_flower/masked_image_37553172-9b38-4727-bf2e-37d7e2b93461.png",
|
201 |
+
"chenduling": "./assets/chenduling/masked_image_68e3ff6f-da07-4b37-91df-13d6eed7b997.png",
|
202 |
+
"hedgehog_rp_bg": "./assets/hedgehog_rp_bg/masked_image_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png",
|
203 |
+
}
|
204 |
+
OUTPUT_IMAGE_PATH = {
|
205 |
+
"frog": "./assets/frog/image_edit_f7b350de-6f2c-49e3-b535-995c486d78e7_1.png",
|
206 |
+
"chinese_girl": "./assets/chinese_girl/image_edit_54759648-0989-48e0-bc82-f20e28b5ec29_1.png",
|
207 |
+
"angel_christmas": "./assets/angel_christmas/image_edit_f15d9b45-c978-4e3d-9f5f-251e308560c3_0.png",
|
208 |
+
"sunflower_girl": "./assets/sunflower_girl/image_edit_99cc50b4-7dc4-4de5-8748-ec10772f0317_3.png",
|
209 |
+
"girl_on_sun": "./assets/girl_on_sun/image_edit_264eac8b-8b65-479c-9755-020a60880c37_0.png",
|
210 |
+
"spider_man_rm": "./assets/spider_man_rm/image_edit_a5d410e6-8e8d-432f-8144-defbc3e1eae9_0.png",
|
211 |
+
"anime_flower": "./assets/anime_flower/image_edit_37553172-9b38-4727-bf2e-37d7e2b93461_2.png",
|
212 |
+
"chenduling": "./assets/chenduling/image_edit_68e3ff6f-da07-4b37-91df-13d6eed7b997_0.png",
|
213 |
+
"hedgehog_rp_bg": "./assets/hedgehog_rp_bg/image_edit_db7f8bf8-8349-46d3-b14e-43d67fbe25d3_3.png",
|
214 |
+
}
|
215 |
+
|
216 |
+
|
217 |
+
# os.environ['GRADIO_TEMP_DIR'] = 'gradio_temp_dir'
|
218 |
+
# os.makedirs('gradio_temp_dir', exist_ok=True)
|
219 |
+
|
220 |
+
VLM_MODEL_NAMES = list(vlms_template.keys())
|
221 |
+
DEFAULT_VLM_MODEL_NAME = "Qwen2-VL-7B-Instruct (Default)"
|
222 |
+
BASE_MODELS = list(base_models_template.keys())
|
223 |
+
DEFAULT_BASE_MODEL = "realisticVision (Default)"
|
224 |
+
|
225 |
+
ASPECT_RATIO_LABELS = list(aspect_ratios)
|
226 |
+
DEFAULT_ASPECT_RATIO = ASPECT_RATIO_LABELS[0]
|
227 |
+
|
228 |
+
device = "cuda:0"
|
229 |
+
torch_dtype = torch.bfloat16
|
230 |
+
|
231 |
+
## init device
|
232 |
+
# try:
|
233 |
+
# if torch.cuda.is_available():
|
234 |
+
# device = "cuda"
|
235 |
+
# print("device = cuda")
|
236 |
+
# elif sys.platform == "darwin" and torch.backends.mps.is_available():
|
237 |
+
# device = "mps"
|
238 |
+
# print("device = mps")
|
239 |
+
# else:
|
240 |
+
# device = "cpu"
|
241 |
+
# print("device = cpu")
|
242 |
+
# except:
|
243 |
+
# device = "cpu"
|
244 |
+
|
245 |
+
|
246 |
+
|
247 |
+
# download hf models
|
248 |
+
BrushEdit_path = "models/"
|
249 |
+
if not os.path.exists(BrushEdit_path):
|
250 |
+
BrushEdit_path = snapshot_download(
|
251 |
+
repo_id="TencentARC/BrushEdit",
|
252 |
+
local_dir=BrushEdit_path,
|
253 |
+
token=os.getenv("HF_TOKEN"),
|
254 |
+
)
|
255 |
+
|
256 |
+
## init default VLM
|
257 |
+
vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[DEFAULT_VLM_MODEL_NAME]
|
258 |
+
if vlm_processor != "" and vlm_model != "":
|
259 |
+
vlm_model.to(device)
|
260 |
+
else:
|
261 |
+
raise gr.Error("Please Download default VLM model "+ DEFAULT_VLM_MODEL_NAME +" first.")
|
262 |
+
|
263 |
+
|
264 |
+
## init base model
|
265 |
+
base_model_path = os.path.join(BrushEdit_path, "base_model/realisticVisionV60B1_v51VAE")
|
266 |
+
brushnet_path = os.path.join(BrushEdit_path, "brushnetX")
|
267 |
+
sam_path = os.path.join(BrushEdit_path, "sam/sam_vit_h_4b8939.pth")
|
268 |
+
groundingdino_path = os.path.join(BrushEdit_path, "grounding_dino/groundingdino_swint_ogc.pth")
|
269 |
+
|
270 |
+
|
271 |
+
# input brushnetX ckpt path
|
272 |
+
brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch_dtype)
|
273 |
+
pipe = StableDiffusionBrushNetPipeline.from_pretrained(
|
274 |
+
base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
|
275 |
+
)
|
276 |
+
# speed up diffusion process with faster scheduler and memory optimization
|
277 |
+
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
|
278 |
+
# remove following line if xformers is not installed or when using Torch 2.0.
|
279 |
+
# pipe.enable_xformers_memory_efficient_attention()
|
280 |
+
pipe.enable_model_cpu_offload()
|
281 |
+
|
282 |
+
|
283 |
+
## init SAM
|
284 |
+
sam = build_sam(checkpoint=sam_path)
|
285 |
+
sam.to(device=device)
|
286 |
+
sam_predictor = SamPredictor(sam)
|
287 |
+
sam_automask_generator = SamAutomaticMaskGenerator(sam)
|
288 |
+
|
289 |
+
## init groundingdino_model
|
290 |
+
config_file = 'app/utils/GroundingDINO_SwinT_OGC.py'
|
291 |
+
groundingdino_model = load_grounding_dino_model(config_file, groundingdino_path, device=device)
|
292 |
+
|
293 |
+
## Ordinary function
|
294 |
+
def crop_and_resize(image: Image.Image,
|
295 |
+
target_width: int,
|
296 |
+
target_height: int) -> Image.Image:
|
297 |
+
"""
|
298 |
+
Crops and resizes an image while preserving the aspect ratio.
|
299 |
+
|
300 |
+
Args:
|
301 |
+
image (Image.Image): Input PIL image to be cropped and resized.
|
302 |
+
target_width (int): Target width of the output image.
|
303 |
+
target_height (int): Target height of the output image.
|
304 |
+
|
305 |
+
Returns:
|
306 |
+
Image.Image: Cropped and resized image.
|
307 |
+
"""
|
308 |
+
# Original dimensions
|
309 |
+
original_width, original_height = image.size
|
310 |
+
original_aspect = original_width / original_height
|
311 |
+
target_aspect = target_width / target_height
|
312 |
+
|
313 |
+
# Calculate crop box to maintain aspect ratio
|
314 |
+
if original_aspect > target_aspect:
|
315 |
+
# Crop horizontally
|
316 |
+
new_width = int(original_height * target_aspect)
|
317 |
+
new_height = original_height
|
318 |
+
left = (original_width - new_width) / 2
|
319 |
+
top = 0
|
320 |
+
right = left + new_width
|
321 |
+
bottom = original_height
|
322 |
+
else:
|
323 |
+
# Crop vertically
|
324 |
+
new_width = original_width
|
325 |
+
new_height = int(original_width / target_aspect)
|
326 |
+
left = 0
|
327 |
+
top = (original_height - new_height) / 2
|
328 |
+
right = original_width
|
329 |
+
bottom = top + new_height
|
330 |
+
|
331 |
+
# Crop and resize
|
332 |
+
cropped_image = image.crop((left, top, right, bottom))
|
333 |
+
resized_image = cropped_image.resize((target_width, target_height), Image.NEAREST)
|
334 |
+
return resized_image
|
335 |
+
|
336 |
+
|
337 |
+
## Ordinary function
|
338 |
+
def resize(image: Image.Image,
|
339 |
+
target_width: int,
|
340 |
+
target_height: int) -> Image.Image:
|
341 |
+
"""
|
342 |
+
Resizes an image to the target width and height without cropping (the aspect ratio is not preserved).
|
343 |
+
|
344 |
+
Args:
|
345 |
+
image (Image.Image): Input PIL image to be cropped and resized.
|
346 |
+
target_width (int): Target width of the output image.
|
347 |
+
target_height (int): Target height of the output image.
|
348 |
+
|
349 |
+
Returns:
|
350 |
+
Image.Image: Resized image.
|
351 |
+
"""
|
352 |
+
# Direct resize to the target dimensions
|
353 |
+
resized_image = image.resize((target_width, target_height), Image.NEAREST)
|
354 |
+
return resized_image
|
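A small usage sketch contrasting the two helpers above (hypothetical example): crop_and_resize center-crops to the target aspect ratio before resizing, while resize simply stretches to the target size.

from PIL import Image

src = Image.new("RGB", (1280, 720))
a = crop_and_resize(src, target_width=640, target_height=640)  # center-crop to 720x720, then resize
b = resize(src, target_width=640, target_height=640)           # direct stretch, aspect ratio changes
print(a.size, b.size)  # (640, 640) (640, 640)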
355 |
+
|
356 |
+
|
357 |
+
def move_mask_func(mask, direction, units):
|
358 |
+
binary_mask = mask.squeeze()>0
|
359 |
+
rows, cols = binary_mask.shape
|
360 |
+
moved_mask = np.zeros_like(binary_mask, dtype=bool)
|
361 |
+
|
362 |
+
if direction == 'down':
|
363 |
+
# move down
|
364 |
+
moved_mask[max(0, units):, :] = binary_mask[:rows - units, :]
|
365 |
+
|
366 |
+
elif direction == 'up':
|
367 |
+
# move up
|
368 |
+
moved_mask[:rows - units, :] = binary_mask[units:, :]
|
369 |
+
|
370 |
+
elif direction == 'right':
|
371 |
+
# move right
|
372 |
+
moved_mask[:, max(0, units):] = binary_mask[:, :cols - units]
|
373 |
+
|
374 |
+
elif direction == 'left':
|
375 |
+
# move left
|
376 |
+
moved_mask[:, :cols - units] = binary_mask[:, units:]
|
377 |
+
|
378 |
+
return moved_mask
|
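A minimal usage sketch of move_mask_func (hypothetical example; the input is any array whose nonzero pixels are foreground, the output is a boolean array of the same height and width):

import numpy as np

demo_mask = np.zeros((5, 5), dtype=np.uint8)
demo_mask[2, 2] = 255

shifted = move_mask_func(demo_mask, 'right', 1)   # True only at (2, 3)
shifted_u8 = shifted.astype(np.uint8) * 255       # back to the 0/255 convention used here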
379 |
+
|
380 |
+
|
381 |
+
def random_mask_func(mask, dilation_type='square', dilation_size=20):
|
382 |
+
# Randomly select the size of dilation
|
383 |
+
binary_mask = mask.squeeze()>0
|
384 |
+
|
385 |
+
if dilation_type == 'square_dilation':
|
386 |
+
structure = np.ones((dilation_size, dilation_size), dtype=bool)
|
387 |
+
dilated_mask = binary_dilation(binary_mask, structure=structure)
|
388 |
+
elif dilation_type == 'square_erosion':
|
389 |
+
structure = np.ones((dilation_size, dilation_size), dtype=bool)
|
390 |
+
dilated_mask = binary_erosion(binary_mask, structure=structure)
|
391 |
+
elif dilation_type == 'bounding_box':
|
392 |
+
# find the most left top and left bottom point
|
393 |
+
rows, cols = np.where(binary_mask)
|
394 |
+
if len(rows) == 0 or len(cols) == 0:
|
395 |
+
return mask # return original mask if no valid points
|
396 |
+
|
397 |
+
min_row = np.min(rows)
|
398 |
+
max_row = np.max(rows)
|
399 |
+
min_col = np.min(cols)
|
400 |
+
max_col = np.max(cols)
|
401 |
+
|
402 |
+
# create a bounding box
|
403 |
+
dilated_mask = np.zeros_like(binary_mask, dtype=bool)
|
404 |
+
dilated_mask[min_row:max_row + 1, min_col:max_col + 1] = True
|
405 |
+
|
406 |
+
elif dilation_type == 'bounding_ellipse':
|
407 |
+
# find the most left top and left bottom point
|
408 |
+
rows, cols = np.where(binary_mask)
|
409 |
+
if len(rows) == 0 or len(cols) == 0:
|
410 |
+
return mask # return original mask if no valid points
|
411 |
+
|
412 |
+
min_row = np.min(rows)
|
413 |
+
max_row = np.max(rows)
|
414 |
+
min_col = np.min(cols)
|
415 |
+
max_col = np.max(cols)
|
416 |
+
|
417 |
+
# calculate the center and axis length of the ellipse
|
418 |
+
center = ((min_col + max_col) // 2, (min_row + max_row) // 2)
|
419 |
+
a = (max_col - min_col) // 2 # half long axis
|
420 |
+
b = (max_row - min_row) // 2 # half short axis
|
421 |
+
|
422 |
+
# create a bounding ellipse
|
423 |
+
y, x = np.ogrid[:mask.shape[0], :mask.shape[1]]
|
424 |
+
ellipse_mask = ((x - center[0])**2 / a**2 + (y - center[1])**2 / b**2) <= 1
|
425 |
+
dilated_mask = np.zeros_like(binary_mask, dtype=bool)
|
426 |
+
dilated_mask[ellipse_mask] = True
|
427 |
+
else:
|
428 |
+
ValueError("dilation_type must be 'square' or 'ellipse'")
|
429 |
+
|
430 |
+
# use binary dilation
|
431 |
+
dilated_mask = np.uint8(dilated_mask[:,:,np.newaxis]) * 255
|
432 |
+
return dilated_mask
|
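A small sketch of the 'square_dilation' branch above (hypothetical example): the mask is binarized, grown with a dilation_size x dilation_size structuring element, and returned as an h x w x 1 uint8 array in {0, 255}.

import numpy as np

seed_mask = np.zeros((64, 64, 1), dtype=np.uint8)
seed_mask[30:34, 30:34, 0] = 255

grown = random_mask_func(seed_mask, dilation_type='square_dilation', dilation_size=20)
print(grown.shape, grown.dtype, grown.max())  # (64, 64, 1) uint8 255, with a larger white square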
433 |
+
|
434 |
+
|
435 |
+
## Gradio component function
|
436 |
+
def update_vlm_model(vlm_name):
|
437 |
+
global vlm_model, vlm_processor
|
438 |
+
if vlm_model is not None:
|
439 |
+
del vlm_model
|
440 |
+
torch.cuda.empty_cache()
|
441 |
+
|
442 |
+
vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[vlm_name]
|
443 |
+
|
444 |
+
## we recommend using preloaded models; otherwise downloading the model can take a long time. You can edit the model list in vlm_template.py
|
445 |
+
if vlm_type == "llava-next":
|
446 |
+
if vlm_processor != "" and vlm_model != "":
|
447 |
+
vlm_model.to(device)
|
448 |
+
return vlm_model_dropdown
|
449 |
+
else:
|
450 |
+
if os.path.exists(vlm_local_path):
|
451 |
+
vlm_processor = LlavaNextProcessor.from_pretrained(vlm_local_path)
|
452 |
+
vlm_model = LlavaNextForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto")
|
453 |
+
else:
|
454 |
+
if vlm_name == "llava-v1.6-mistral-7b-hf (Preload)":
|
455 |
+
vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
|
456 |
+
vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", torch_dtype="auto", device_map="auto")
|
457 |
+
elif vlm_name == "llama3-llava-next-8b-hf (Preload)":
|
458 |
+
vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llama3-llava-next-8b-hf")
|
459 |
+
vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llama3-llava-next-8b-hf", torch_dtype="auto", device_map="auto")
|
460 |
+
elif vlm_name == "llava-v1.6-vicuna-13b-hf (Preload)":
|
461 |
+
vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf")
|
462 |
+
vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf", torch_dtype="auto", device_map="auto")
|
463 |
+
elif vlm_name == "llava-v1.6-34b-hf (Preload)":
|
464 |
+
vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-34b-hf")
|
465 |
+
vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-34b-hf", torch_dtype="auto", device_map="auto")
|
466 |
+
elif vlm_name == "llava-next-72b-hf (Preload)":
|
467 |
+
vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-next-72b-hf")
|
468 |
+
vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-next-72b-hf", torch_dtype="auto", device_map="auto")
|
469 |
+
elif vlm_type == "qwen2-vl":
|
470 |
+
if vlm_processor != "" and vlm_model != "":
|
471 |
+
vlm_model.to(device)
|
472 |
+
return vlm_model_dropdown
|
473 |
+
else:
|
474 |
+
if os.path.exists(vlm_local_path):
|
475 |
+
vlm_processor = Qwen2VLProcessor.from_pretrained(vlm_local_path)
|
476 |
+
vlm_model = Qwen2VLForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto")
|
477 |
+
else:
|
478 |
+
if vlm_name == "qwen2-vl-2b-instruct (Preload)":
|
479 |
+
vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
|
480 |
+
vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", torch_dtype="auto", device_map="auto")
|
481 |
+
elif vlm_name == "qwen2-vl-7b-instruct (Preload)":
|
482 |
+
vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
|
483 |
+
vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto")
|
484 |
+
elif vlm_name == "qwen2-vl-72b-instruct (Preload)":
|
485 |
+
vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-72B-Instruct")
|
486 |
+
vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-72B-Instruct", torch_dtype="auto", device_map="auto")
|
487 |
+
elif vlm_type == "openai":
|
488 |
+
pass
|
489 |
+
return "success"
|
490 |
+
|
491 |
+
|
492 |
+
def update_base_model(base_model_name):
|
493 |
+
global pipe
|
494 |
+
## we recommend using preloaded models; otherwise downloading the model can take a long time. You can edit the model list in base_model_template.py
|
495 |
+
if pipe is not None:
|
496 |
+
del pipe
|
497 |
+
torch.cuda.empty_cache()
|
498 |
+
base_model_path, pipe = base_models_template[base_model_name]
|
499 |
+
if pipe != "":
|
500 |
+
pipe.to(device)
|
501 |
+
else:
|
502 |
+
if os.path.exists(base_model_path):
|
503 |
+
pipe = StableDiffusionBrushNetPipeline.from_pretrained(
|
504 |
+
base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
|
505 |
+
)
|
506 |
+
# pipe.enable_xformers_memory_efficient_attention()
|
507 |
+
pipe.enable_model_cpu_offload()
|
508 |
+
else:
|
509 |
+
raise gr.Error(f"The base model {base_model_name} does not exist")
|
510 |
+
return "success"
|
511 |
+
|
512 |
+
|
513 |
+
def submit_GPT4o_KEY(GPT4o_KEY):
|
514 |
+
global vlm_model, vlm_processor
|
515 |
+
if vlm_model is not None:
|
516 |
+
del vlm_model
|
517 |
+
torch.cuda.empty_cache()
|
518 |
+
try:
|
519 |
+
vlm_model = OpenAI(api_key=GPT4o_KEY)
|
520 |
+
vlm_processor = ""
|
521 |
+
response = vlm_model.chat.completions.create(
|
522 |
+
model="gpt-4o-2024-08-06",
|
523 |
+
messages=[
|
524 |
+
{"role": "system", "content": "You are a helpful assistant."},
|
525 |
+
{"role": "user", "content": "Say this is a test"}
|
526 |
+
]
|
527 |
+
)
|
528 |
+
response_str = response.choices[0].message.content
|
529 |
+
|
530 |
+
return "Success, " + response_str, "GPT4-o (Highly Recommended)"
|
531 |
+
except Exception as e:
|
532 |
+
return "Invalid GPT4o API Key", "GPT4-o (Highly Recommended)"
|
533 |
+
|
534 |
+
|
535 |
+
|
536 |
+
def process(input_image,
|
537 |
+
original_image,
|
538 |
+
original_mask,
|
539 |
+
prompt,
|
540 |
+
negative_prompt,
|
541 |
+
control_strength,
|
542 |
+
seed,
|
543 |
+
randomize_seed,
|
544 |
+
guidance_scale,
|
545 |
+
num_inference_steps,
|
546 |
+
num_samples,
|
547 |
+
blending,
|
548 |
+
category,
|
549 |
+
target_prompt,
|
550 |
+
resize_default,
|
551 |
+
aspect_ratio_name,
|
552 |
+
invert_mask_state):
|
553 |
+
if original_image is None:
|
554 |
+
if input_image is None:
|
555 |
+
raise gr.Error('Please upload the input image')
|
556 |
+
else:
|
557 |
+
image_pil = input_image["background"].convert("RGB")
|
558 |
+
original_image = np.array(image_pil)
|
559 |
+
if prompt is None or prompt == "":
|
560 |
+
if target_prompt is None or target_prompt == "":
|
561 |
+
raise gr.Error("Please input your instructions, e.g., remove the xxx")
|
562 |
+
|
563 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
564 |
+
input_mask = np.asarray(alpha_mask)
|
565 |
+
output_w, output_h = aspect_ratios[aspect_ratio_name]
|
566 |
+
if output_w == "" or output_h == "":
|
567 |
+
output_h, output_w = original_image.shape[:2]
|
568 |
+
|
569 |
+
if resize_default:
|
570 |
+
short_side = min(output_w, output_h)
|
571 |
+
scale_ratio = 640 / short_side
|
572 |
+
output_w = int(output_w * scale_ratio)
|
573 |
+
output_h = int(output_h * scale_ratio)
|
574 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
575 |
+
original_image = np.array(original_image)
|
576 |
+
if input_mask is not None:
|
577 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
578 |
+
input_mask = np.array(input_mask)
|
579 |
+
if original_mask is not None:
|
580 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
581 |
+
original_mask = np.array(original_mask)
|
582 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
583 |
+
else:
|
584 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
585 |
+
pass
|
586 |
+
else:
|
587 |
+
if resize_default:
|
588 |
+
short_side = min(output_w, output_h)
|
589 |
+
scale_ratio = 640 / short_side
|
590 |
+
output_w = int(output_w * scale_ratio)
|
591 |
+
output_h = int(output_h * scale_ratio)
|
592 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
593 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
594 |
+
original_image = np.array(original_image)
|
595 |
+
if input_mask is not None:
|
596 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
597 |
+
input_mask = np.array(input_mask)
|
598 |
+
if original_mask is not None:
|
599 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
600 |
+
original_mask = np.array(original_mask)
|
601 |
+
|
602 |
+
if invert_mask_state:
|
603 |
+
original_mask = original_mask
|
604 |
+
else:
|
605 |
+
if input_mask.max() == 0:
|
606 |
+
original_mask = original_mask
|
607 |
+
else:
|
608 |
+
original_mask = input_mask
|
609 |
+
|
610 |
+
|
611 |
+
## inpainting directly if target_prompt is not None
|
612 |
+
if category is not None:
|
613 |
+
pass
|
614 |
+
elif target_prompt is not None and len(target_prompt) >= 1 and original_mask is not None:
|
615 |
+
pass
|
616 |
+
else:
|
617 |
+
try:
|
618 |
+
category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device)
|
619 |
+
except Exception as e:
|
620 |
+
raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
|
621 |
+
|
622 |
+
|
623 |
+
if original_mask is not None:
|
624 |
+
original_mask = np.clip(original_mask, 0, 255).astype(np.uint8)
|
625 |
+
else:
|
626 |
+
try:
|
627 |
+
object_wait_for_edit = vlm_response_object_wait_for_edit(
|
628 |
+
vlm_processor,
|
629 |
+
vlm_model,
|
630 |
+
original_image,
|
631 |
+
category,
|
632 |
+
prompt,
|
633 |
+
device)
|
634 |
+
|
635 |
+
original_mask = vlm_response_mask(vlm_processor,
|
636 |
+
vlm_model,
|
637 |
+
category,
|
638 |
+
original_image,
|
639 |
+
prompt,
|
640 |
+
object_wait_for_edit,
|
641 |
+
sam,
|
642 |
+
sam_predictor,
|
643 |
+
sam_automask_generator,
|
644 |
+
groundingdino_model,
|
645 |
+
device).astype(np.uint8)
|
646 |
+
except Exception as e:
|
647 |
+
raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
|
648 |
+
|
649 |
+
if original_mask.ndim == 2:
|
650 |
+
original_mask = original_mask[:,:,None]
|
651 |
+
|
652 |
+
|
653 |
+
if target_prompt is not None and len(target_prompt) >= 1:
|
654 |
+
prompt_after_apply_instruction = target_prompt
|
655 |
+
|
656 |
+
else:
|
657 |
+
try:
|
658 |
+
prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(
|
659 |
+
vlm_processor,
|
660 |
+
vlm_model,
|
661 |
+
original_image,
|
662 |
+
prompt,
|
663 |
+
device)
|
664 |
+
except Exception as e:
|
665 |
+
raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
|
666 |
+
|
667 |
+
generator = torch.Generator(device).manual_seed(random.randint(0, 2147483647) if randomize_seed else seed)
|
668 |
+
|
669 |
+
|
670 |
+
with torch.autocast("cuda"):  # torch.autocast expects a device type ("cuda"), not an indexed device like "cuda:0"
|
671 |
+
image, mask_image, mask_np, init_image_np = BrushEdit_Pipeline(pipe,
|
672 |
+
prompt_after_apply_instruction,
|
673 |
+
original_mask,
|
674 |
+
original_image,
|
675 |
+
generator,
|
676 |
+
num_inference_steps,
|
677 |
+
guidance_scale,
|
678 |
+
control_strength,
|
679 |
+
negative_prompt,
|
680 |
+
num_samples,
|
681 |
+
blending)
|
682 |
+
original_image = np.array(init_image_np)
|
683 |
+
masked_image = original_image * (1 - (mask_np>0))
|
684 |
+
masked_image = masked_image.astype(np.uint8)
|
685 |
+
masked_image = Image.fromarray(masked_image)
|
686 |
+
# Save the images (optional)
|
687 |
+
# import uuid
|
688 |
+
# uuid = str(uuid.uuid4())
|
689 |
+
# image[0].save(f"outputs/image_edit_{uuid}_0.png")
|
690 |
+
# image[1].save(f"outputs/image_edit_{uuid}_1.png")
|
691 |
+
# image[2].save(f"outputs/image_edit_{uuid}_2.png")
|
692 |
+
# image[3].save(f"outputs/image_edit_{uuid}_3.png")
|
693 |
+
# mask_image.save(f"outputs/mask_{uuid}.png")
|
694 |
+
# masked_image.save(f"outputs/masked_image_{uuid}.png")
|
695 |
+
gr.Info(f"Target Prompt: {prompt_after_apply_instruction}", duration=20)
|
696 |
+
return image, [mask_image], [masked_image], prompt, '', False
|
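# The six returned values appear to line up with the outputs wired to the Run button elsewhere in this app:
# result images -> result gallery, [mask_image] -> mask gallery, [masked_image] -> masked-image gallery,
# prompt -> instruction box, '' clears the target prompt, False resets invert_mask_state.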
697 |
+
|
698 |
+
|
699 |
+
def generate_target_prompt(input_image,
|
700 |
+
original_image,
|
701 |
+
prompt):
|
702 |
+
# load example image
|
703 |
+
if isinstance(original_image, str):
|
704 |
+
original_image = input_image
|
705 |
+
|
706 |
+
prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(
|
707 |
+
vlm_processor,
|
708 |
+
vlm_model,
|
709 |
+
original_image,
|
710 |
+
prompt,
|
711 |
+
device)
|
712 |
+
return prompt_after_apply_instruction
|
713 |
+
|
714 |
+
|
715 |
+
def process_mask(input_image,
                 original_image,
                 prompt,
                 resize_default,
                 aspect_ratio_name):
    if original_image is None:
        raise gr.Error('Please upload the input image')
    if prompt is None:
        raise gr.Error("Please input your instructions, e.g., remove the xxx")

    ## load mask
    alpha_mask = input_image["layers"][0].split()[3]
    input_mask = np.array(alpha_mask)

    # load example image
    if isinstance(original_image, str):
        original_image = input_image["background"]

    if input_mask.max() == 0:
        category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device)

        object_wait_for_edit = vlm_response_object_wait_for_edit(vlm_processor,
                                                                 vlm_model,
                                                                 original_image,
                                                                 category,
                                                                 prompt,
                                                                 device)
        # original mask: h,w,1 [0, 255]
        original_mask = vlm_response_mask(
            vlm_processor,
            vlm_model,
            category,
            original_image,
            prompt,
            object_wait_for_edit,
            sam,
            sam_predictor,
            sam_automask_generator,
            groundingdino_model,
            device).astype(np.uint8)
    else:
        original_mask = input_mask.astype(np.uint8)
        category = None

    ## resize mask if needed
    output_w, output_h = aspect_ratios[aspect_ratio_name]
    if output_w == "" or output_h == "":
        output_h, output_w = original_image.shape[:2]
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
            original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
            original_image = np.array(original_image)
            if input_mask is not None:
                input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
                input_mask = np.array(input_mask)
            if original_mask is not None:
                original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
                original_mask = np.array(original_mask)
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        else:
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
            pass
    else:
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
        original_image = np.array(original_image)
        if input_mask is not None:
            input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
            input_mask = np.array(input_mask)
        if original_mask is not None:
            original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
            original_mask = np.array(original_mask)

    if original_mask.ndim == 2:
        original_mask = original_mask[:,:,None]

    mask_image = Image.fromarray(original_mask.squeeze().astype(np.uint8)).convert("RGB")

    masked_image = original_image * (1 - (original_mask>0))
    masked_image = masked_image.astype(np.uint8)
    masked_image = Image.fromarray(masked_image)

    return [masked_image], [mask_image], original_mask.astype(np.uint8), category

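# Note: the aspect-ratio / "short edge resize to 640px" handling below is repeated
# verbatim in each of the mask helpers that follow (random, dilation, erosion and
# the four move_mask_* functions). Illustrative example (values not from the app):
# a 1280x720 input with resize_default=True is scaled by 640/720, i.e. to 1137x640.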
def process_random_mask(input_image,
                        original_image,
                        original_mask,
                        resize_default,
                        aspect_ratio_name,
                        ):

    alpha_mask = input_image["layers"][0].split()[3]
    input_mask = np.asarray(alpha_mask)

    output_w, output_h = aspect_ratios[aspect_ratio_name]
    if output_w == "" or output_h == "":
        output_h, output_w = original_image.shape[:2]
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
            original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
            original_image = np.array(original_image)
            if input_mask is not None:
                input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
                input_mask = np.array(input_mask)
            if original_mask is not None:
                original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
                original_mask = np.array(original_mask)
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        else:
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
            pass
    else:
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
        original_image = np.array(original_image)
        if input_mask is not None:
            input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
            input_mask = np.array(input_mask)
        if original_mask is not None:
            original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
            original_mask = np.array(original_mask)

    if input_mask.max() == 0:
        original_mask = original_mask
    else:
        original_mask = input_mask

    if original_mask is None:
        raise gr.Error('Please generate mask first')

    if original_mask.ndim == 2:
        original_mask = original_mask[:,:,None]

    dilation_type = np.random.choice(['bounding_box', 'bounding_ellipse'])
    random_mask = random_mask_func(original_mask, dilation_type).squeeze()

    mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")

    masked_image = original_image * (1 - (random_mask[:,:,None]>0))
    masked_image = masked_image.astype(original_image.dtype)
    masked_image = Image.fromarray(masked_image)

    return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)

def process_dilation_mask(input_image,
                          original_image,
                          original_mask,
                          resize_default,
                          aspect_ratio_name,
                          dilation_size=20):

    alpha_mask = input_image["layers"][0].split()[3]
    input_mask = np.asarray(alpha_mask)

    output_w, output_h = aspect_ratios[aspect_ratio_name]
    if output_w == "" or output_h == "":
        output_h, output_w = original_image.shape[:2]
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
            original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
            original_image = np.array(original_image)
            if input_mask is not None:
                input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
                input_mask = np.array(input_mask)
            if original_mask is not None:
                original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
                original_mask = np.array(original_mask)
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        else:
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
            pass
    else:
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
        original_image = np.array(original_image)
        if input_mask is not None:
            input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
            input_mask = np.array(input_mask)
        if original_mask is not None:
            original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
            original_mask = np.array(original_mask)

    if input_mask.max() == 0:
        original_mask = original_mask
    else:
        original_mask = input_mask

    if original_mask is None:
        raise gr.Error('Please generate mask first')

    if original_mask.ndim == 2:
        original_mask = original_mask[:,:,None]

    dilation_type = np.random.choice(['square_dilation'])
    random_mask = random_mask_func(original_mask, dilation_type, dilation_size).squeeze()

    mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")

    masked_image = original_image * (1 - (random_mask[:,:,None]>0))
    masked_image = masked_image.astype(original_image.dtype)
    masked_image = Image.fromarray(masked_image)

    return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)

def process_erosion_mask(input_image,
                         original_image,
                         original_mask,
                         resize_default,
                         aspect_ratio_name,
                         dilation_size=20):
    alpha_mask = input_image["layers"][0].split()[3]
    input_mask = np.asarray(alpha_mask)
    output_w, output_h = aspect_ratios[aspect_ratio_name]
    if output_w == "" or output_h == "":
        output_h, output_w = original_image.shape[:2]
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
            original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
            original_image = np.array(original_image)
            if input_mask is not None:
                input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
                input_mask = np.array(input_mask)
            if original_mask is not None:
                original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
                original_mask = np.array(original_mask)
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        else:
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
            pass
    else:
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
        original_image = np.array(original_image)
        if input_mask is not None:
            input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
            input_mask = np.array(input_mask)
        if original_mask is not None:
            original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
            original_mask = np.array(original_mask)

    if input_mask.max() == 0:
        original_mask = original_mask
    else:
        original_mask = input_mask

    if original_mask is None:
        raise gr.Error('Please generate mask first')

    if original_mask.ndim == 2:
        original_mask = original_mask[:,:,None]

    dilation_type = np.random.choice(['square_erosion'])
    random_mask = random_mask_func(original_mask, dilation_type, dilation_size).squeeze()

    mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")

    masked_image = original_image * (1 - (random_mask[:,:,None]>0))
    masked_image = masked_image.astype(original_image.dtype)
    masked_image = Image.fromarray(masked_image)

    return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)

def move_mask_left(input_image,
                   original_image,
                   original_mask,
                   moving_pixels,
                   resize_default,
                   aspect_ratio_name):

    alpha_mask = input_image["layers"][0].split()[3]
    input_mask = np.asarray(alpha_mask)

    output_w, output_h = aspect_ratios[aspect_ratio_name]
    if output_w == "" or output_h == "":
        output_h, output_w = original_image.shape[:2]
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
            original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
            original_image = np.array(original_image)
            if input_mask is not None:
                input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
                input_mask = np.array(input_mask)
            if original_mask is not None:
                original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
                original_mask = np.array(original_mask)
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        else:
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
            pass
    else:
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
        original_image = np.array(original_image)
        if input_mask is not None:
            input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
            input_mask = np.array(input_mask)
        if original_mask is not None:
            original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
            original_mask = np.array(original_mask)

    if input_mask.max() == 0:
        original_mask = original_mask
    else:
        original_mask = input_mask

    if original_mask is None:
        raise gr.Error('Please generate mask first')

    if original_mask.ndim == 2:
        original_mask = original_mask[:,:,None]

    moved_mask = move_mask_func(original_mask, 'left', int(moving_pixels)).squeeze()
    mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")

    masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
    masked_image = masked_image.astype(original_image.dtype)
    masked_image = Image.fromarray(masked_image)

    if moved_mask.max() <= 1:
        moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
        original_mask = moved_mask
    return [masked_image], [mask_image], original_mask.astype(np.uint8)

def move_mask_right(input_image,
                    original_image,
                    original_mask,
                    moving_pixels,
                    resize_default,
                    aspect_ratio_name):
    alpha_mask = input_image["layers"][0].split()[3]
    input_mask = np.asarray(alpha_mask)

    output_w, output_h = aspect_ratios[aspect_ratio_name]
    if output_w == "" or output_h == "":
        output_h, output_w = original_image.shape[:2]
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
            original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
            original_image = np.array(original_image)
            if input_mask is not None:
                input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
                input_mask = np.array(input_mask)
            if original_mask is not None:
                original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
                original_mask = np.array(original_mask)
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        else:
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
            pass
    else:
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
        original_image = np.array(original_image)
        if input_mask is not None:
            input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
            input_mask = np.array(input_mask)
        if original_mask is not None:
            original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
            original_mask = np.array(original_mask)

    if input_mask.max() == 0:
        original_mask = original_mask
    else:
        original_mask = input_mask

    if original_mask is None:
        raise gr.Error('Please generate mask first')

    if original_mask.ndim == 2:
        original_mask = original_mask[:,:,None]

    moved_mask = move_mask_func(original_mask, 'right', int(moving_pixels)).squeeze()

    mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")

    masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
    masked_image = masked_image.astype(original_image.dtype)
    masked_image = Image.fromarray(masked_image)

    if moved_mask.max() <= 1:
        moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
        original_mask = moved_mask

    return [masked_image], [mask_image], original_mask.astype(np.uint8)

def move_mask_up(input_image,
                 original_image,
                 original_mask,
                 moving_pixels,
                 resize_default,
                 aspect_ratio_name):
    alpha_mask = input_image["layers"][0].split()[3]
    input_mask = np.asarray(alpha_mask)

    output_w, output_h = aspect_ratios[aspect_ratio_name]
    if output_w == "" or output_h == "":
        output_h, output_w = original_image.shape[:2]
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
            original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
            original_image = np.array(original_image)
            if input_mask is not None:
                input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
                input_mask = np.array(input_mask)
            if original_mask is not None:
                original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
                original_mask = np.array(original_mask)
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        else:
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
            pass
    else:
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
        original_image = np.array(original_image)
        if input_mask is not None:
            input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
            input_mask = np.array(input_mask)
        if original_mask is not None:
            original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
            original_mask = np.array(original_mask)

    if input_mask.max() == 0:
        original_mask = original_mask
    else:
        original_mask = input_mask

    if original_mask is None:
        raise gr.Error('Please generate mask first')

    if original_mask.ndim == 2:
        original_mask = original_mask[:,:,None]

    moved_mask = move_mask_func(original_mask, 'up', int(moving_pixels)).squeeze()
    mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")

    masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
    masked_image = masked_image.astype(original_image.dtype)
    masked_image = Image.fromarray(masked_image)

    if moved_mask.max() <= 1:
        moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
        original_mask = moved_mask

    return [masked_image], [mask_image], original_mask.astype(np.uint8)

def move_mask_down(input_image,
                   original_image,
                   original_mask,
                   moving_pixels,
                   resize_default,
                   aspect_ratio_name):
    alpha_mask = input_image["layers"][0].split()[3]
    input_mask = np.asarray(alpha_mask)
    output_w, output_h = aspect_ratios[aspect_ratio_name]
    if output_w == "" or output_h == "":
        output_h, output_w = original_image.shape[:2]
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
            original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
            original_image = np.array(original_image)
            if input_mask is not None:
                input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
                input_mask = np.array(input_mask)
            if original_mask is not None:
                original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
                original_mask = np.array(original_mask)
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        else:
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
            pass
    else:
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
        original_image = np.array(original_image)
        if input_mask is not None:
            input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
            input_mask = np.array(input_mask)
        if original_mask is not None:
            original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
            original_mask = np.array(original_mask)

    if input_mask.max() == 0:
        original_mask = original_mask
    else:
        original_mask = input_mask

    if original_mask is None:
        raise gr.Error('Please generate mask first')

    if original_mask.ndim == 2:
        original_mask = original_mask[:,:,None]

    moved_mask = move_mask_func(original_mask, 'down', int(moving_pixels)).squeeze()
    mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")

    masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
    masked_image = masked_image.astype(original_image.dtype)
    masked_image = Image.fromarray(masked_image)

    if moved_mask.max() <= 1:
        moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
        original_mask = moved_mask

    return [masked_image], [mask_image], original_mask.astype(np.uint8)

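# invert_mask flips the current selection (the painted layer if present, otherwise
# the previously generated mask) so that subsequent edits apply outside of it.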
def invert_mask(input_image,
                original_image,
                original_mask,
                ):
    alpha_mask = input_image["layers"][0].split()[3]
    input_mask = np.asarray(alpha_mask)
    if input_mask.max() == 0:
        original_mask = 1 - (original_mask>0).astype(np.uint8)
    else:
        original_mask = 1 - (input_mask>0).astype(np.uint8)

    if original_mask is None:
        raise gr.Error('Please generate mask first')

    original_mask = original_mask.squeeze()
    mask_image = Image.fromarray(original_mask*255).convert("RGB")

    if original_mask.ndim == 2:
        original_mask = original_mask[:,:,None]

    if original_mask.max() <= 1:
        original_mask = (original_mask * 255).astype(np.uint8)

    masked_image = original_image * (1 - (original_mask>0))
    masked_image = masked_image.astype(original_image.dtype)
    masked_image = Image.fromarray(masked_image)

    return [masked_image], [mask_image], original_mask, True

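# init_img pre-populates the mask/masked/result galleries with the cached example
# assets the first time a known example is loaded (example_change_times < 2);
# otherwise it resets the state for a freshly uploaded image.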
def init_img(base,
             init_type,
             prompt,
             aspect_ratio,
             example_change_times
             ):
    image_pil = base["background"].convert("RGB")
    original_image = np.array(image_pil)
    if max(original_image.shape[0], original_image.shape[1]) * 1.0 / min(original_image.shape[0], original_image.shape[1])>2.0:
        raise gr.Error('image aspect ratio cannot be larger than 2.0')
    if init_type in MASK_IMAGE_PATH.keys() and example_change_times < 2:
        mask_gallery = [Image.open(MASK_IMAGE_PATH[init_type]).convert("L")]
        masked_gallery = [Image.open(MASKED_IMAGE_PATH[init_type]).convert("RGB")]
        result_gallery = [Image.open(OUTPUT_IMAGE_PATH[init_type]).convert("RGB")]
        width, height = image_pil.size
        image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
        height_new, width_new = image_processor.get_default_height_width(image_pil, height, width)
        image_pil = image_pil.resize((width_new, height_new))
        mask_gallery[0] = mask_gallery[0].resize((width_new, height_new))
        masked_gallery[0] = masked_gallery[0].resize((width_new, height_new))
        result_gallery[0] = result_gallery[0].resize((width_new, height_new))
        original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1
        return base, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, "", "", "Custom resolution", False, False, example_change_times
    else:
        if aspect_ratio not in ASPECT_RATIO_LABELS:
            aspect_ratio = "Custom resolution"
        return base, original_image, None, "", None, None, None, "", "", aspect_ratio, True, False, 0

def reset_func(input_image,
               original_image,
               original_mask,
               prompt,
               target_prompt,
               ):
    input_image = None
    original_image = None
    original_mask = None
    prompt = ''
    mask_gallery = []
    masked_gallery = []
    result_gallery = []
    target_prompt = ''
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, True, False

def update_example(example_type,
                   prompt,
                   example_change_times):
    input_image = INPUT_IMAGE_PATH[example_type]
    image_pil = Image.open(input_image).convert("RGB")
    mask_gallery = [Image.open(MASK_IMAGE_PATH[example_type]).convert("L")]
    masked_gallery = [Image.open(MASKED_IMAGE_PATH[example_type]).convert("RGB")]
    result_gallery = [Image.open(OUTPUT_IMAGE_PATH[example_type]).convert("RGB")]
    width, height = image_pil.size
    image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
    height_new, width_new = image_processor.get_default_height_width(image_pil, height, width)
    image_pil = image_pil.resize((width_new, height_new))
    mask_gallery[0] = mask_gallery[0].resize((width_new, height_new))
    masked_gallery[0] = masked_gallery[0].resize((width_new, height_new))
    result_gallery[0] = result_gallery[0].resize((width_new, height_new))

    original_image = np.array(image_pil)
    original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1
    aspect_ratio = "Custom resolution"
    example_change_times += 1
    return input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, "", False, example_change_times

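# --- Gradio UI ---
# Everything below builds the demo layout and wires the buttons to the helper
# functions defined above.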
block = gr.Blocks(
    theme=gr.themes.Soft(
        radius_size=gr.themes.sizes.radius_none,
        text_size=gr.themes.sizes.text_md
    )
)
with block as demo:
    with gr.Row():
        with gr.Column():
            gr.HTML(head)

    gr.Markdown(descriptions)

    with gr.Accordion(label="🧭 Instructions:", open=True, elem_id="accordion"):
        with gr.Row(equal_height=True):
            gr.Markdown(instructions)

    original_image = gr.State(value=None)
    original_mask = gr.State(value=None)
    category = gr.State(value=None)
    status = gr.State(value=None)
    invert_mask_state = gr.State(value=False)
    example_change_times = gr.State(value=0)


    with gr.Row():
        with gr.Column():
            with gr.Row():
                input_image = gr.ImageEditor(
                    label="Input Image",
                    type="pil",
                    brush=gr.Brush(colors=["#FFFFFF"], default_size = 30, color_mode="fixed"),
                    layers = False,
                    interactive=True,
                    height=1024,
                    sources=["upload"],
                    placeholder="Please click here or the icon below to upload the image.",
                )

            prompt = gr.Textbox(label="⌨️ Instruction", placeholder="Please input your instruction.", value="",lines=1)
            run_button = gr.Button("💫 Run")

            vlm_model_dropdown = gr.Dropdown(label="VLM model", choices=VLM_MODEL_NAMES, value=DEFAULT_VLM_MODEL_NAME, interactive=True)
            with gr.Group():
                with gr.Row():
                    GPT4o_KEY = gr.Textbox(label="GPT4o API Key", placeholder="Please input your GPT4o API Key when use GPT4o VLM (highly recommended).", value="", lines=1)

                    GPT4o_KEY_submit = gr.Button("Submit and Verify")


            aspect_ratio = gr.Dropdown(label="Output aspect ratio", choices=ASPECT_RATIO_LABELS, value=DEFAULT_ASPECT_RATIO)
            resize_default = gr.Checkbox(label="Short edge resize to 640px", value=True)

            with gr.Row():
                mask_button = gr.Button("Generate Mask")
                random_mask_button = gr.Button("Square/Circle Mask ")


            with gr.Row():
                generate_target_prompt_button = gr.Button("Generate Target Prompt")

            target_prompt = gr.Text(
                label="Input Target Prompt",
                max_lines=5,
                placeholder="VLM-generated target prompt, you can first generate if and then modify it (optional)",
                value='',
                lines=2
            )

            with gr.Accordion("Advanced Options", open=False, elem_id="accordion1"):
                base_model_dropdown = gr.Dropdown(label="Base model", choices=BASE_MODELS, value=DEFAULT_BASE_MODEL, interactive=True)
                negative_prompt = gr.Text(
                    label="Negative Prompt",
                    max_lines=5,
                    placeholder="Please input your negative prompt",
                    value='ugly, low quality',lines=1
                )

                control_strength = gr.Slider(
                    label="Control Strength: ", show_label=True, minimum=0, maximum=1.1, value=1, step=0.01
                )
                with gr.Group():
                    seed = gr.Slider(
                        label="Seed: ", minimum=0, maximum=2147483647, step=1, value=648464818
                    )
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=False)

                blending = gr.Checkbox(label="Blending mode", value=True)


                num_samples = gr.Slider(
                    label="Num samples", minimum=0, maximum=4, step=1, value=4
                )

                with gr.Group():
                    with gr.Row():
                        guidance_scale = gr.Slider(
                            label="Guidance scale",
                            minimum=1,
                            maximum=12,
                            step=0.1,
                            value=7.5,
                        )
                        num_inference_steps = gr.Slider(
                            label="Number of inference steps",
                            minimum=1,
                            maximum=50,
                            step=1,
                            value=50,
                        )


        with gr.Column():
            with gr.Row():
                with gr.Tab(elem_classes="feedback", label="Masked Image"):
                    masked_gallery = gr.Gallery(label='Masked Image', show_label=True, elem_id="gallery", preview=True, height=360)
                with gr.Tab(elem_classes="feedback", label="Mask"):
                    mask_gallery = gr.Gallery(label='Mask', show_label=True, elem_id="gallery", preview=True, height=360)

            invert_mask_button = gr.Button("Invert Mask")
            dilation_size = gr.Slider(
                label="Dilation size: ", minimum=0, maximum=50, step=1, value=20
            )
            with gr.Row():
                dilation_mask_button = gr.Button("Dilation Generated Mask")
                erosion_mask_button = gr.Button("Erosion Generated Mask")

            moving_pixels = gr.Slider(
                label="Moving pixels:", show_label=True, minimum=0, maximum=50, value=4, step=1
            )
            with gr.Row():
                move_left_button = gr.Button("Move Left")
                move_right_button = gr.Button("Move Right")
            with gr.Row():
                move_up_button = gr.Button("Move Up")
                move_down_button = gr.Button("Move Down")

            with gr.Tab(elem_classes="feedback", label="Output"):
                result_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", preview=True, height=400)

            # target_prompt_output = gr.Text(label="Output Target Prompt", value="", lines=1, interactive=False)

            reset_button = gr.Button("Reset")

            init_type = gr.Textbox(label="Init Name", value="", visible=False)
            example_type = gr.Textbox(label="Example Name", value="", visible=False)



    with gr.Row():
        example = gr.Examples(
            label="Quick Example",
            examples=EXAMPLES,
            inputs=[input_image, prompt, seed, init_type, example_type, blending, resize_default, vlm_model_dropdown],
            examples_per_page=10,
            cache_examples=False,
        )


    with gr.Accordion(label="🎬 Feature Details:", open=True, elem_id="accordion"):
        with gr.Row(equal_height=True):
            gr.Markdown(tips)

    with gr.Row():
        gr.Markdown(citation)

    ## gr.examples can not be used to update the gr.Gallery, so we need to use the following two functions to update the gr.Gallery.
    ## And we need to solve the conflict between the upload and change example functions.
    input_image.upload(
        init_img,
        [input_image, init_type, prompt, aspect_ratio, example_change_times],
        [input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, init_type, aspect_ratio, resize_default, invert_mask_state, example_change_times]
    )
    example_type.change(fn=update_example, inputs=[example_type, prompt, example_change_times], outputs=[input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, target_prompt, invert_mask_state, example_change_times])

    ## vlm and base model dropdown
    vlm_model_dropdown.change(fn=update_vlm_model, inputs=[vlm_model_dropdown], outputs=[status])
    base_model_dropdown.change(fn=update_base_model, inputs=[base_model_dropdown], outputs=[status])


    GPT4o_KEY_submit.click(fn=submit_GPT4o_KEY, inputs=[GPT4o_KEY], outputs=[GPT4o_KEY, vlm_model_dropdown])
    invert_mask_button.click(fn=invert_mask, inputs=[input_image, original_image, original_mask], outputs=[masked_gallery, mask_gallery, original_mask, invert_mask_state])


    ips=[input_image,
         original_image,
         original_mask,
         prompt,
         negative_prompt,
         control_strength,
         seed,
         randomize_seed,
         guidance_scale,
         num_inference_steps,
         num_samples,
         blending,
         category,
         target_prompt,
         resize_default,
         aspect_ratio,
         invert_mask_state]

    ## run brushedit
    run_button.click(fn=process, inputs=ips, outputs=[result_gallery, mask_gallery, masked_gallery, prompt, target_prompt, invert_mask_state])

    ## mask func
    mask_button.click(fn=process_mask, inputs=[input_image, original_image, prompt, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask, category])
    random_mask_button.click(fn=process_random_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
    dilation_mask_button.click(fn=process_dilation_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[ masked_gallery, mask_gallery, original_mask])
    erosion_mask_button.click(fn=process_erosion_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[ masked_gallery, mask_gallery, original_mask])

    ## move mask func
    move_left_button.click(fn=move_mask_left, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
    move_right_button.click(fn=move_mask_right, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
    move_up_button.click(fn=move_mask_up, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
    move_down_button.click(fn=move_mask_down, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])

    ## prompt func
    generate_target_prompt_button.click(fn=generate_target_prompt, inputs=[input_image, original_image, prompt], outputs=[target_prompt])

    ## reset func
    reset_button.click(fn=reset_func, inputs=[input_image, original_image, original_mask, prompt, target_prompt], outputs=[input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, resize_default, invert_mask_state])

## if have a localhost access error, try to use the following code
demo.launch(server_name="0.0.0.0", server_port=12345, share=True)
# demo.launch()
brushedit_app_315_2.py
ADDED
@@ -0,0 +1,1627 @@
1 |
+
##!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import os, random, sys
|
4 |
+
import numpy as np
|
5 |
+
import requests
|
6 |
+
import torch
|
7 |
+
|
8 |
+
|
9 |
+
import gradio as gr
|
10 |
+
|
11 |
+
from PIL import Image
|
12 |
+
|
13 |
+
|
14 |
+
from huggingface_hub import hf_hub_download, snapshot_download
|
15 |
+
from scipy.ndimage import binary_dilation, binary_erosion
|
16 |
+
from transformers import (LlavaNextProcessor, LlavaNextForConditionalGeneration,
|
17 |
+
Qwen2VLForConditionalGeneration, Qwen2VLProcessor)
|
18 |
+
|
19 |
+
from segment_anything import SamPredictor, build_sam, SamAutomaticMaskGenerator
|
20 |
+
from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler
|
21 |
+
from diffusers.image_processor import VaeImageProcessor
|
22 |
+
|
23 |
+
|
24 |
+
from app.src.vlm_pipeline import (
|
25 |
+
vlm_response_editing_type,
|
26 |
+
vlm_response_object_wait_for_edit,
|
27 |
+
vlm_response_mask,
|
28 |
+
vlm_response_prompt_after_apply_instruction
|
29 |
+
)
|
30 |
+
from app.src.brushedit_all_in_one_pipeline import BrushEdit_Pipeline
|
31 |
+
from app.utils.utils import load_grounding_dino_model
|
32 |
+
|
33 |
+
from app.src.vlm_template import vlms_template
|
34 |
+
from app.src.base_model_template import base_models_template
|
35 |
+
from app.src.aspect_ratio_template import aspect_ratios
|
36 |
+
|
37 |
+
from openai import OpenAI
|
38 |
+
# base_openai_url = ""
|
39 |
+
|
40 |
+
#### Description ####
|
41 |
+
logo = r"""
|
42 |
+
<center><img src='./assets/logo_brushedit.png' alt='BrushEdit logo' style="width:80px; margin-bottom:10px"></center>
|
43 |
+
"""
|
44 |
+
head = r"""
|
45 |
+
<div style="text-align: center;">
|
46 |
+
<h1> BrushEdit: All-In-One Image Inpainting and Editing</h1>
|
47 |
+
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
|
48 |
+
<a href='https://liyaowei-stu.github.io/project/BrushEdit/'><img src='https://img.shields.io/badge/Project_Page-BrushEdit-green' alt='Project Page'></a>
|
49 |
+
<a href='https://arxiv.org/abs/2412.10316'><img src='https://img.shields.io/badge/Paper-Arxiv-blue'></a>
|
50 |
+
<a href='https://github.com/TencentARC/BrushEdit'><img src='https://img.shields.io/badge/Code-Github-orange'></a>
|
51 |
+
|
52 |
+
</div>
|
53 |
+
</br>
|
54 |
+
</div>
|
55 |
+
"""
|
56 |
+
descriptions = r"""
|
57 |
+
Official Gradio Demo
|
58 |
+
"""
|
59 |
+
|
60 |
+
instructions = r"""
|
61 |
+
等待补充
|
62 |
+
"""
|
63 |
+
|
64 |
+
tips = r"""
|
65 |
+
等待补充
|
66 |
+
|
67 |
+
"""
|
68 |
+
|
69 |
+
|
70 |
+
|
71 |
+
citation = r"""
|
72 |
+
等待补充
|
73 |
+
"""
|
74 |
+
|
75 |
+
# - - - - - examples - - - - - #
|
76 |
+
EXAMPLES = [
|
77 |
+
|
78 |
+
[
|
79 |
+
Image.open("./assets/frog/frog.jpeg").convert("RGBA"),
|
80 |
+
"add a magic hat on frog head.",
|
81 |
+
642087011,
|
82 |
+
"frog",
|
83 |
+
"frog",
|
84 |
+
True,
|
85 |
+
False,
|
86 |
+
"GPT4-o (Highly Recommended)"
|
87 |
+
],
|
88 |
+
[
|
89 |
+
Image.open("./assets/chinese_girl/chinese_girl.png").convert("RGBA"),
|
90 |
+
"replace the background to ancient China.",
|
91 |
+
648464818,
|
92 |
+
"chinese_girl",
|
93 |
+
"chinese_girl",
|
94 |
+
True,
|
95 |
+
False,
|
96 |
+
"GPT4-o (Highly Recommended)"
|
97 |
+
],
|
98 |
+
[
|
99 |
+
Image.open("./assets/angel_christmas/angel_christmas.png").convert("RGBA"),
|
100 |
+
"remove the deer.",
|
101 |
+
648464818,
|
102 |
+
"angel_christmas",
|
103 |
+
"angel_christmas",
|
104 |
+
False,
|
105 |
+
False,
|
106 |
+
"GPT4-o (Highly Recommended)"
|
107 |
+
],
|
108 |
+
[
|
109 |
+
Image.open("./assets/sunflower_girl/sunflower_girl.png").convert("RGBA"),
|
110 |
+
"add a wreath on head.",
|
111 |
+
648464818,
|
112 |
+
"sunflower_girl",
|
113 |
+
"sunflower_girl",
|
114 |
+
True,
|
115 |
+
False,
|
116 |
+
"GPT4-o (Highly Recommended)"
|
117 |
+
],
|
118 |
+
[
|
119 |
+
Image.open("./assets/girl_on_sun/girl_on_sun.png").convert("RGBA"),
|
120 |
+
"add a butterfly fairy.",
|
121 |
+
648464818,
|
122 |
+
"girl_on_sun",
|
123 |
+
"girl_on_sun",
|
124 |
+
True,
|
125 |
+
False,
|
126 |
+
"GPT4-o (Highly Recommended)"
|
127 |
+
],
|
128 |
+
[
|
129 |
+
Image.open("./assets/spider_man_rm/spider_man.png").convert("RGBA"),
|
130 |
+
"remove the christmas hat.",
|
131 |
+
642087011,
|
132 |
+
"spider_man_rm",
|
133 |
+
"spider_man_rm",
|
134 |
+
False,
|
135 |
+
False,
|
136 |
+
"GPT4-o (Highly Recommended)"
|
137 |
+
],
|
138 |
+
[
|
139 |
+
Image.open("./assets/anime_flower/anime_flower.png").convert("RGBA"),
|
140 |
+
"remove the flower.",
|
141 |
+
642087011,
|
142 |
+
"anime_flower",
|
143 |
+
"anime_flower",
|
144 |
+
False,
|
145 |
+
False,
|
146 |
+
"GPT4-o (Highly Recommended)"
|
147 |
+
],
|
148 |
+
[
|
149 |
+
Image.open("./assets/chenduling/chengduling.jpg").convert("RGBA"),
|
150 |
+
"replace the clothes to a delicated floral skirt.",
|
151 |
+
648464818,
|
152 |
+
"chenduling",
|
153 |
+
"chenduling",
|
154 |
+
True,
|
155 |
+
False,
|
156 |
+
"GPT4-o (Highly Recommended)"
|
157 |
+
],
|
158 |
+
[
|
159 |
+
Image.open("./assets/hedgehog_rp_bg/hedgehog.png").convert("RGBA"),
|
160 |
+
"make the hedgehog in Italy.",
|
161 |
+
648464818,
|
162 |
+
"hedgehog_rp_bg",
|
163 |
+
"hedgehog_rp_bg",
|
164 |
+
True,
|
165 |
+
False,
|
166 |
+
"GPT4-o (Highly Recommended)"
|
167 |
+
],
|
168 |
+
|
169 |
+
]
|
170 |
+
|
171 |
+
INPUT_IMAGE_PATH = {
|
172 |
+
"frog": "./assets/frog/frog.jpeg",
|
173 |
+
"chinese_girl": "./assets/chinese_girl/chinese_girl.png",
|
174 |
+
"angel_christmas": "./assets/angel_christmas/angel_christmas.png",
|
175 |
+
"sunflower_girl": "./assets/sunflower_girl/sunflower_girl.png",
|
176 |
+
"girl_on_sun": "./assets/girl_on_sun/girl_on_sun.png",
|
177 |
+
"spider_man_rm": "./assets/spider_man_rm/spider_man.png",
|
178 |
+
"anime_flower": "./assets/anime_flower/anime_flower.png",
|
179 |
+
"chenduling": "./assets/chenduling/chengduling.jpg",
|
180 |
+
"hedgehog_rp_bg": "./assets/hedgehog_rp_bg/hedgehog.png",
|
181 |
+
}
|
182 |
+
MASK_IMAGE_PATH = {
|
183 |
+
"frog": "./assets/frog/mask_f7b350de-6f2c-49e3-b535-995c486d78e7.png",
|
184 |
+
"chinese_girl": "./assets/chinese_girl/mask_54759648-0989-48e0-bc82-f20e28b5ec29.png",
|
185 |
+
"angel_christmas": "./assets/angel_christmas/mask_f15d9b45-c978-4e3d-9f5f-251e308560c3.png",
|
186 |
+
"sunflower_girl": "./assets/sunflower_girl/mask_99cc50b4-7dc4-4de5-8748-ec10772f0317.png",
|
187 |
+
"girl_on_sun": "./assets/girl_on_sun/mask_264eac8b-8b65-479c-9755-020a60880c37.png",
|
188 |
+
"spider_man_rm": "./assets/spider_man_rm/mask_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png",
|
189 |
+
"anime_flower": "./assets/anime_flower/mask_37553172-9b38-4727-bf2e-37d7e2b93461.png",
|
190 |
+
"chenduling": "./assets/chenduling/mask_68e3ff6f-da07-4b37-91df-13d6eed7b997.png",
|
191 |
+
"hedgehog_rp_bg": "./assets/hedgehog_rp_bg/mask_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png",
|
192 |
+
}
|
193 |
+
MASKED_IMAGE_PATH = {
|
194 |
+
"frog": "./assets/frog/masked_image_f7b350de-6f2c-49e3-b535-995c486d78e7.png",
|
195 |
+
"chinese_girl": "./assets/chinese_girl/masked_image_54759648-0989-48e0-bc82-f20e28b5ec29.png",
|
196 |
+
"angel_christmas": "./assets/angel_christmas/masked_image_f15d9b45-c978-4e3d-9f5f-251e308560c3.png",
|
197 |
+
"sunflower_girl": "./assets/sunflower_girl/masked_image_99cc50b4-7dc4-4de5-8748-ec10772f0317.png",
|
198 |
+
"girl_on_sun": "./assets/girl_on_sun/masked_image_264eac8b-8b65-479c-9755-020a60880c37.png",
|
199 |
+
"spider_man_rm": "./assets/spider_man_rm/masked_image_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png",
|
200 |
+
"anime_flower": "./assets/anime_flower/masked_image_37553172-9b38-4727-bf2e-37d7e2b93461.png",
|
201 |
+
"chenduling": "./assets/chenduling/masked_image_68e3ff6f-da07-4b37-91df-13d6eed7b997.png",
|
202 |
+
"hedgehog_rp_bg": "./assets/hedgehog_rp_bg/masked_image_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png",
|
203 |
+
}
|
204 |
+
OUTPUT_IMAGE_PATH = {
|
205 |
+
"frog": "./assets/frog/image_edit_f7b350de-6f2c-49e3-b535-995c486d78e7_1.png",
|
206 |
+
"chinese_girl": "./assets/chinese_girl/image_edit_54759648-0989-48e0-bc82-f20e28b5ec29_1.png",
|
207 |
+
"angel_christmas": "./assets/angel_christmas/image_edit_f15d9b45-c978-4e3d-9f5f-251e308560c3_0.png",
|
208 |
+
"sunflower_girl": "./assets/sunflower_girl/image_edit_99cc50b4-7dc4-4de5-8748-ec10772f0317_3.png",
|
209 |
+
"girl_on_sun": "./assets/girl_on_sun/image_edit_264eac8b-8b65-479c-9755-020a60880c37_0.png",
|
210 |
+
"spider_man_rm": "./assets/spider_man_rm/image_edit_a5d410e6-8e8d-432f-8144-defbc3e1eae9_0.png",
|
211 |
+
"anime_flower": "./assets/anime_flower/image_edit_37553172-9b38-4727-bf2e-37d7e2b93461_2.png",
|
212 |
+
"chenduling": "./assets/chenduling/image_edit_68e3ff6f-da07-4b37-91df-13d6eed7b997_0.png",
|
213 |
+
"hedgehog_rp_bg": "./assets/hedgehog_rp_bg/image_edit_db7f8bf8-8349-46d3-b14e-43d67fbe25d3_3.png",
|
214 |
+
}
|
215 |
+
|
216 |
+
|
217 |
+
# os.environ['GRADIO_TEMP_DIR'] = 'gradio_temp_dir'
|
218 |
+
# os.makedirs('gradio_temp_dir', exist_ok=True)
|
219 |
+
|
220 |
+
VLM_MODEL_NAMES = list(vlms_template.keys())
|
221 |
+
DEFAULT_VLM_MODEL_NAME = "Qwen2-VL-7B-Instruct (Default)"
|
222 |
+
BASE_MODELS = list(base_models_template.keys())
|
223 |
+
DEFAULT_BASE_MODEL = "realisticVision (Default)"
|
224 |
+
|
225 |
+
ASPECT_RATIO_LABELS = list(aspect_ratios)
|
226 |
+
DEFAULT_ASPECT_RATIO = ASPECT_RATIO_LABELS[0]
|
227 |
+
|
228 |
+
## init device
|
229 |
+
try:
|
230 |
+
if torch.cuda.is_available():
|
231 |
+
device = "cuda:0"
|
232 |
+
elif sys.platform == "darwin" and torch.backends.mps.is_available():
|
233 |
+
device = "mps"
|
234 |
+
else:
|
235 |
+
device = "cpu"
|
236 |
+
except:
|
237 |
+
device = "cpu"
|
238 |
+
|
239 |
+
## init torch dtype
|
240 |
+
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
|
241 |
+
torch_dtype = torch.bfloat16
|
242 |
+
else:
|
243 |
+
torch_dtype = torch.float16
|
244 |
+
|
245 |
+
if device == "mps":
|
246 |
+
torch_dtype = torch.float16
|
247 |
+
|
248 |
+
|
249 |
+
|
250 |
+
# download hf models
|
251 |
+
BrushEdit_path = "models/"
|
252 |
+
if not os.path.exists(BrushEdit_path):
|
253 |
+
BrushEdit_path = snapshot_download(
|
254 |
+
repo_id="TencentARC/BrushEdit",
|
255 |
+
local_dir=BrushEdit_path,
|
256 |
+
token=os.getenv("HF_TOKEN"),
|
257 |
+
)
|
258 |
+
|
259 |
+
## init default VLM
|
260 |
+
vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[DEFAULT_VLM_MODEL_NAME]
|
261 |
+
if vlm_processor != "" and vlm_model != "":
|
262 |
+
vlm_model.to(device)
|
263 |
+
else:
|
264 |
+
raise gr.Error("Please Download default VLM model "+ DEFAULT_VLM_MODEL_NAME +" first.")
|
265 |
+
|
266 |
+
|
267 |
+
## init base model
|
268 |
+
base_model_path = os.path.join(BrushEdit_path, "base_model/realisticVisionV60B1_v51VAE")
|
269 |
+
brushnet_path = os.path.join(BrushEdit_path, "brushnetX")
|
270 |
+
sam_path = os.path.join(BrushEdit_path, "sam/sam_vit_h_4b8939.pth")
|
271 |
+
groundingdino_path = os.path.join(BrushEdit_path, "grounding_dino/groundingdino_swint_ogc.pth")
|
272 |
+
|
273 |
+
|
274 |
+
# input brushnetX ckpt path
|
275 |
+
brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch_dtype)
|
276 |
+
pipe = StableDiffusionBrushNetPipeline.from_pretrained(
|
277 |
+
base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
|
278 |
+
)
|
279 |
+
# speed up diffusion process with faster scheduler and memory optimization
|
280 |
+
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
|
281 |
+
# remove following line if xformers is not installed or when using Torch 2.0.
|
282 |
+
# pipe.enable_xformers_memory_efficient_attention()
|
283 |
+
pipe.enable_model_cpu_offload()
|
284 |
+
|
285 |
+
|
286 |
+
## init SAM
|
287 |
+
sam = build_sam(checkpoint=sam_path)
|
288 |
+
sam.to(device=device)
|
289 |
+
sam_predictor = SamPredictor(sam)
|
290 |
+
sam_automask_generator = SamAutomaticMaskGenerator(sam)
|
291 |
+
|
292 |
+
## init groundingdino_model
|
293 |
+
config_file = 'app/utils/GroundingDINO_SwinT_OGC.py'
|
294 |
+
groundingdino_model = load_grounding_dino_model(config_file, groundingdino_path, device=device)
|
295 |
+
|
296 |
+
## Ordinary function
|
297 |
+
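# The two helpers below differ in how they reach the target size: `crop_and_resize`
# center-crops to the target aspect ratio before resampling, while `resize` stretches
# the image directly to the requested resolution.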
def crop_and_resize(image: Image.Image,
                    target_width: int,
                    target_height: int) -> Image.Image:
    """
    Crops and resizes an image while preserving the aspect ratio.

    Args:
        image (Image.Image): Input PIL image to be cropped and resized.
        target_width (int): Target width of the output image.
        target_height (int): Target height of the output image.

    Returns:
        Image.Image: Cropped and resized image.
    """
    # Original dimensions
    original_width, original_height = image.size
    original_aspect = original_width / original_height
    target_aspect = target_width / target_height

    # Calculate crop box to maintain aspect ratio
    if original_aspect > target_aspect:
        # Crop horizontally
        new_width = int(original_height * target_aspect)
        new_height = original_height
        left = (original_width - new_width) / 2
        top = 0
        right = left + new_width
        bottom = original_height
    else:
        # Crop vertically
        new_width = original_width
        new_height = int(original_width / target_aspect)
        left = 0
        top = (original_height - new_height) / 2
        right = original_width
        bottom = top + new_height

    # Crop and resize
    cropped_image = image.crop((left, top, right, bottom))
    resized_image = cropped_image.resize((target_width, target_height), Image.NEAREST)
    return resized_image

## Ordinary function
def resize(image: Image.Image,
           target_width: int,
           target_height: int) -> Image.Image:
    """
    Resizes an image directly to the target resolution (no cropping, so the
    aspect ratio is not preserved).

    Args:
        image (Image.Image): Input PIL image to be resized.
        target_width (int): Target width of the output image.
        target_height (int): Target height of the output image.

    Returns:
        Image.Image: Resized image.
    """
    resized_image = image.resize((target_width, target_height), Image.NEAREST)
    return resized_image

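# `move_mask_func` shifts a binary mask by `units` pixels in one of four directions,
# zero-filling the vacated border. A hypothetical call would look like:
#   moved = move_mask_func(original_mask, 'left', 10)  # shift the masked region 10 px left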
def move_mask_func(mask, direction, units):
    binary_mask = mask.squeeze() > 0
    rows, cols = binary_mask.shape
    moved_mask = np.zeros_like(binary_mask, dtype=bool)

    if direction == 'down':
        # move down
        moved_mask[max(0, units):, :] = binary_mask[:rows - units, :]

    elif direction == 'up':
        # move up
        moved_mask[:rows - units, :] = binary_mask[units:, :]

    elif direction == 'right':
        # move right
        moved_mask[:, max(0, units):] = binary_mask[:, :cols - units]

    elif direction == 'left':
        # move left
        moved_mask[:, :cols - units] = binary_mask[:, units:]

    return moved_mask

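# `random_mask_func` loosens or tightens the current mask: 'square_dilation' and
# 'square_erosion' apply a dilation_size x dilation_size structuring element, while
# 'bounding_box' and 'bounding_ellipse' replace the mask with its enclosing shape.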
def random_mask_func(mask, dilation_type='square', dilation_size=20):
    binary_mask = mask.squeeze() > 0

    if dilation_type == 'square_dilation':
        structure = np.ones((dilation_size, dilation_size), dtype=bool)
        dilated_mask = binary_dilation(binary_mask, structure=structure)
    elif dilation_type == 'square_erosion':
        structure = np.ones((dilation_size, dilation_size), dtype=bool)
        dilated_mask = binary_erosion(binary_mask, structure=structure)
    elif dilation_type == 'bounding_box':
        # find the bounding box of the mask
        rows, cols = np.where(binary_mask)
        if len(rows) == 0 or len(cols) == 0:
            return mask  # return original mask if no valid points

        min_row = np.min(rows)
        max_row = np.max(rows)
        min_col = np.min(cols)
        max_col = np.max(cols)

        # create a bounding box
        dilated_mask = np.zeros_like(binary_mask, dtype=bool)
        dilated_mask[min_row:max_row + 1, min_col:max_col + 1] = True

    elif dilation_type == 'bounding_ellipse':
        # find the bounding box of the mask
        rows, cols = np.where(binary_mask)
        if len(rows) == 0 or len(cols) == 0:
            return mask  # return original mask if no valid points

        min_row = np.min(rows)
        max_row = np.max(rows)
        min_col = np.min(cols)
        max_col = np.max(cols)

        # calculate the center and axis lengths of the ellipse
        center = ((min_col + max_col) // 2, (min_row + max_row) // 2)
        a = (max_col - min_col) // 2  # half long axis
        b = (max_row - min_row) // 2  # half short axis

        # create a bounding ellipse
        y, x = np.ogrid[:mask.shape[0], :mask.shape[1]]
        ellipse_mask = ((x - center[0])**2 / a**2 + (y - center[1])**2 / b**2) <= 1
        dilated_mask = np.zeros_like(binary_mask, dtype=bool)
        dilated_mask[ellipse_mask] = True
    else:
        raise ValueError("dilation_type must be 'square_dilation', 'square_erosion', 'bounding_box' or 'bounding_ellipse'")

    # convert back to a uint8 mask in [0, 255] with a trailing channel dimension
    dilated_mask = np.uint8(dilated_mask[:,:,np.newaxis]) * 255
    return dilated_mask

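# The callbacks below are bound to Gradio events: they swap VLM/base models, build and
# edit masks, and finally run the BrushNet inpainting pipeline.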
## Gradio component function
def update_vlm_model(vlm_name):
    global vlm_model, vlm_processor
    if vlm_model is not None:
        del vlm_model
        torch.cuda.empty_cache()

    vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[vlm_name]

    ## we recommend using preloaded models; otherwise downloading the checkpoint here takes a long time. You can edit the preload list in vlm_template.py.
    if vlm_type == "llava-next":
        if vlm_processor != "" and vlm_model != "":
            vlm_model.to(device)
            return vlm_model_dropdown
        else:
            if os.path.exists(vlm_local_path):
                vlm_processor = LlavaNextProcessor.from_pretrained(vlm_local_path)
                vlm_model = LlavaNextForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto")
            else:
                if vlm_name == "llava-v1.6-mistral-7b-hf (Preload)":
                    vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
                    vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", torch_dtype="auto", device_map="auto")
                elif vlm_name == "llama3-llava-next-8b-hf (Preload)":
                    vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llama3-llava-next-8b-hf")
                    vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llama3-llava-next-8b-hf", torch_dtype="auto", device_map="auto")
                elif vlm_name == "llava-v1.6-vicuna-13b-hf (Preload)":
                    vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf")
                    vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf", torch_dtype="auto", device_map="auto")
                elif vlm_name == "llava-v1.6-34b-hf (Preload)":
                    vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-34b-hf")
                    vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-34b-hf", torch_dtype="auto", device_map="auto")
                elif vlm_name == "llava-next-72b-hf (Preload)":
                    vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-next-72b-hf")
                    vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-next-72b-hf", torch_dtype="auto", device_map="auto")
    elif vlm_type == "qwen2-vl":
        if vlm_processor != "" and vlm_model != "":
            vlm_model.to(device)
            return vlm_model_dropdown
        else:
            if os.path.exists(vlm_local_path):
                vlm_processor = Qwen2VLProcessor.from_pretrained(vlm_local_path)
                vlm_model = Qwen2VLForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto")
            else:
                if vlm_name == "qwen2-vl-2b-instruct (Preload)":
                    vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
                    vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", torch_dtype="auto", device_map="auto")
                elif vlm_name == "qwen2-vl-7b-instruct (Preload)":
                    vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
                    vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto")
                elif vlm_name == "qwen2-vl-72b-instruct (Preload)":
                    vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-72B-Instruct")
                    vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-72B-Instruct", torch_dtype="auto", device_map="auto")
    elif vlm_type == "openai":
        pass
    return "success"

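# Switching the base model rebuilds the StableDiffusionBrushNetPipeline around the newly
# selected checkpoint while reusing the already-loaded BrushNet weights.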
def update_base_model(base_model_name):
    global pipe
    ## we recommend using preloaded models; otherwise downloading the checkpoint here takes a long time. You can edit the preload list in base_model_template.py.
    if pipe is not None:
        del pipe
        torch.cuda.empty_cache()
    base_model_path, pipe = base_models_template[base_model_name]
    if pipe != "":
        pipe.to(device)
    else:
        if os.path.exists(base_model_path):
            pipe = StableDiffusionBrushNetPipeline.from_pretrained(
                base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
            )
            # pipe.enable_xformers_memory_efficient_attention()
            pipe.enable_model_cpu_offload()
        else:
            raise gr.Error(f"The base model {base_model_name} does not exist")
    return "success"


def submit_GPT4o_KEY(GPT4o_KEY):
    global vlm_model, vlm_processor
    if vlm_model is not None:
        del vlm_model
        torch.cuda.empty_cache()
    try:
        vlm_model = OpenAI(api_key=GPT4o_KEY)
        vlm_processor = ""
        response = vlm_model.chat.completions.create(
            model="gpt-4o-2024-08-06",
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Say this is a test"}
            ]
        )
        response_str = response.choices[0].message.content

        return "Success, " + response_str, "GPT4-o (Highly Recommended)"
    except Exception as e:
        return "Invalid GPT4o API Key", "GPT4-o (Highly Recommended)"

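# `process` is the main editing entry point: it normalises the image and mask to the
# requested aspect ratio, queries the VLM for the editing category, the target object and
# the target prompt when they are not supplied, and then runs BrushEdit_Pipeline to inpaint.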
def process(input_image, original_image, original_mask, prompt, negative_prompt, control_strength,
            seed, randomize_seed, guidance_scale, num_inference_steps, num_samples, blending,
            category, target_prompt, resize_default, aspect_ratio_name, invert_mask_state):
    if original_image is None:
        if input_image is None:
            raise gr.Error('Please upload the input image')
        else:
            image_pil = input_image["background"].convert("RGB")
            original_image = np.array(image_pil)
    if prompt is None or prompt == "":
        if target_prompt is None or target_prompt == "":
            raise gr.Error("Please input your instructions, e.g., remove the xxx")

    alpha_mask = input_image["layers"][0].split()[3]
    input_mask = np.asarray(alpha_mask)
    output_w, output_h = aspect_ratios[aspect_ratio_name]
    if output_w == "" or output_h == "":
        output_h, output_w = original_image.shape[:2]

        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
            original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
            original_image = np.array(original_image)
            if input_mask is not None:
                input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
                input_mask = np.array(input_mask)
            if original_mask is not None:
                original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
                original_mask = np.array(original_mask)
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        else:
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
            pass
    else:
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
        original_image = np.array(original_image)
        if input_mask is not None:
            input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
            input_mask = np.array(input_mask)
        if original_mask is not None:
            original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
            original_mask = np.array(original_mask)

    if invert_mask_state:
        original_mask = original_mask
    else:
        if input_mask.max() == 0:
            original_mask = original_mask
        else:
            original_mask = input_mask

    ## inpaint directly if a target prompt is already provided
    if category is not None:
        pass
    elif target_prompt is not None and len(target_prompt) >= 1 and original_mask is not None:
        pass
    else:
        try:
            category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device)
        except Exception as e:
            raise gr.Error("Please select the correct VLM model and input the correct API Key first!")

    if original_mask is not None:
        original_mask = np.clip(original_mask, 0, 255).astype(np.uint8)
    else:
        try:
            object_wait_for_edit = vlm_response_object_wait_for_edit(
                vlm_processor, vlm_model, original_image, category, prompt, device)

            original_mask = vlm_response_mask(
                vlm_processor, vlm_model, category, original_image, prompt, object_wait_for_edit,
                sam, sam_predictor, sam_automask_generator, groundingdino_model, device).astype(np.uint8)
        except Exception as e:
            raise gr.Error("Please select the correct VLM model and input the correct API Key first!")

    if original_mask.ndim == 2:
        original_mask = original_mask[:,:,None]

    if target_prompt is not None and len(target_prompt) >= 1:
        prompt_after_apply_instruction = target_prompt
    else:
        try:
            prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(
                vlm_processor, vlm_model, original_image, prompt, device)
        except Exception as e:
            raise gr.Error("Please select the correct VLM model and input the correct API Key first!")

    generator = torch.Generator(device).manual_seed(random.randint(0, 2147483647) if randomize_seed else seed)

    with torch.autocast(device):
        image, mask_image, mask_np, init_image_np = BrushEdit_Pipeline(
            pipe, prompt_after_apply_instruction, original_mask, original_image, generator,
            num_inference_steps, guidance_scale, control_strength, negative_prompt,
            num_samples, blending)
    original_image = np.array(init_image_np)
    masked_image = original_image * (1 - (mask_np > 0))
    masked_image = masked_image.astype(np.uint8)
    masked_image = Image.fromarray(masked_image)
    # Save the images (optional)
    # import uuid
    # uuid = str(uuid.uuid4())
    # image[0].save(f"outputs/image_edit_{uuid}_0.png")
    # image[1].save(f"outputs/image_edit_{uuid}_1.png")
    # image[2].save(f"outputs/image_edit_{uuid}_2.png")
    # image[3].save(f"outputs/image_edit_{uuid}_3.png")
    # mask_image.save(f"outputs/mask_{uuid}.png")
    # masked_image.save(f"outputs/masked_image_{uuid}.png")
    gr.Info(f"Target Prompt: {prompt_after_apply_instruction}", duration=20)
    return image, [mask_image], [masked_image], prompt, '', False

def generate_target_prompt(input_image, original_image, prompt):
    # load example image
    if isinstance(original_image, str):
        original_image = input_image

    prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(
        vlm_processor, vlm_model, original_image, prompt, device)
    return prompt_after_apply_instruction

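# `process_mask` derives an editing mask from the instruction alone: when the user has not
# painted a mask, the VLM names the object to edit and GroundingDINO + SAM segment it.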
def process_mask(input_image, original_image, prompt, resize_default, aspect_ratio_name):
    if original_image is None:
        raise gr.Error('Please upload the input image')
    if prompt is None:
        raise gr.Error("Please input your instructions, e.g., remove the xxx")

    ## load mask
    alpha_mask = input_image["layers"][0].split()[3]
    input_mask = np.array(alpha_mask)

    # load example image
    if isinstance(original_image, str):
        original_image = input_image["background"]

    if input_mask.max() == 0:
        category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device)

        object_wait_for_edit = vlm_response_object_wait_for_edit(
            vlm_processor, vlm_model, original_image, category, prompt, device)
        # original mask: h,w,1 [0, 255]
        original_mask = vlm_response_mask(
            vlm_processor, vlm_model, category, original_image, prompt, object_wait_for_edit,
            sam, sam_predictor, sam_automask_generator, groundingdino_model, device).astype(np.uint8)
    else:
        original_mask = input_mask.astype(np.uint8)
        category = None

    ## resize mask if needed
    output_w, output_h = aspect_ratios[aspect_ratio_name]
    if output_w == "" or output_h == "":
        output_h, output_w = original_image.shape[:2]
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
            original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
            original_image = np.array(original_image)
            if input_mask is not None:
                input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
                input_mask = np.array(input_mask)
            if original_mask is not None:
                original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
                original_mask = np.array(original_mask)
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        else:
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
            pass
    else:
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
        original_image = np.array(original_image)
        if input_mask is not None:
            input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
            input_mask = np.array(input_mask)
        if original_mask is not None:
            original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
            original_mask = np.array(original_mask)

    if original_mask.ndim == 2:
        original_mask = original_mask[:,:,None]

    mask_image = Image.fromarray(original_mask.squeeze().astype(np.uint8)).convert("RGB")

    masked_image = original_image * (1 - (original_mask > 0))
    masked_image = masked_image.astype(np.uint8)
    masked_image = Image.fromarray(masked_image)

    return [masked_image], [mask_image], original_mask.astype(np.uint8), category

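# The next three helpers derive coarser or tighter masks from the current one
# (bounding shapes, dilation, erosion) so mask coverage can be adjusted quickly.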
def process_random_mask(input_image, original_image, original_mask, resize_default, aspect_ratio_name):
    alpha_mask = input_image["layers"][0].split()[3]
    input_mask = np.asarray(alpha_mask)

    output_w, output_h = aspect_ratios[aspect_ratio_name]
    if output_w == "" or output_h == "":
        output_h, output_w = original_image.shape[:2]
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
            original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
            original_image = np.array(original_image)
            if input_mask is not None:
                input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
                input_mask = np.array(input_mask)
            if original_mask is not None:
                original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
                original_mask = np.array(original_mask)
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        else:
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
            pass
    else:
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
        original_image = np.array(original_image)
        if input_mask is not None:
            input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
            input_mask = np.array(input_mask)
        if original_mask is not None:
            original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
            original_mask = np.array(original_mask)

    if input_mask.max() == 0:
        original_mask = original_mask
    else:
        original_mask = input_mask

    if original_mask is None:
        raise gr.Error('Please generate mask first')

    if original_mask.ndim == 2:
        original_mask = original_mask[:,:,None]

    dilation_type = np.random.choice(['bounding_box', 'bounding_ellipse'])
    random_mask = random_mask_func(original_mask, dilation_type).squeeze()

    mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")

    masked_image = original_image * (1 - (random_mask[:,:,None] > 0))
    masked_image = masked_image.astype(original_image.dtype)
    masked_image = Image.fromarray(masked_image)

    return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)

def process_dilation_mask(input_image, original_image, original_mask, resize_default, aspect_ratio_name, dilation_size=20):
    alpha_mask = input_image["layers"][0].split()[3]
    input_mask = np.asarray(alpha_mask)

    output_w, output_h = aspect_ratios[aspect_ratio_name]
    if output_w == "" or output_h == "":
        output_h, output_w = original_image.shape[:2]
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
            original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
            original_image = np.array(original_image)
            if input_mask is not None:
                input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
                input_mask = np.array(input_mask)
            if original_mask is not None:
                original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
                original_mask = np.array(original_mask)
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        else:
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
            pass
    else:
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
        original_image = np.array(original_image)
        if input_mask is not None:
            input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
            input_mask = np.array(input_mask)
        if original_mask is not None:
            original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
            original_mask = np.array(original_mask)

    if input_mask.max() == 0:
        original_mask = original_mask
    else:
        original_mask = input_mask

    if original_mask is None:
        raise gr.Error('Please generate mask first')

    if original_mask.ndim == 2:
        original_mask = original_mask[:,:,None]

    dilation_type = np.random.choice(['square_dilation'])
    random_mask = random_mask_func(original_mask, dilation_type, dilation_size).squeeze()

    mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")

    masked_image = original_image * (1 - (random_mask[:,:,None] > 0))
    masked_image = masked_image.astype(original_image.dtype)
    masked_image = Image.fromarray(masked_image)

    return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)

def process_erosion_mask(input_image, original_image, original_mask, resize_default, aspect_ratio_name, dilation_size=20):
    alpha_mask = input_image["layers"][0].split()[3]
    input_mask = np.asarray(alpha_mask)
    output_w, output_h = aspect_ratios[aspect_ratio_name]
    if output_w == "" or output_h == "":
        output_h, output_w = original_image.shape[:2]
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
            original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
            original_image = np.array(original_image)
            if input_mask is not None:
                input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
                input_mask = np.array(input_mask)
            if original_mask is not None:
                original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
                original_mask = np.array(original_mask)
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        else:
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
            pass
    else:
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
        original_image = np.array(original_image)
        if input_mask is not None:
            input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
            input_mask = np.array(input_mask)
        if original_mask is not None:
            original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
            original_mask = np.array(original_mask)

    if input_mask.max() == 0:
        original_mask = original_mask
    else:
        original_mask = input_mask

    if original_mask is None:
        raise gr.Error('Please generate mask first')

    if original_mask.ndim == 2:
        original_mask = original_mask[:,:,None]

    dilation_type = np.random.choice(['square_erosion'])
    random_mask = random_mask_func(original_mask, dilation_type, dilation_size).squeeze()

    mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")

    masked_image = original_image * (1 - (random_mask[:,:,None] > 0))
    masked_image = masked_image.astype(original_image.dtype)
    masked_image = Image.fromarray(masked_image)

    return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)

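# The four move_mask_* callbacks share the same resize bookkeeping and differ only in the
# direction string passed to move_mask_func.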
def move_mask_left(input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio_name):
    alpha_mask = input_image["layers"][0].split()[3]
    input_mask = np.asarray(alpha_mask)

    output_w, output_h = aspect_ratios[aspect_ratio_name]
    if output_w == "" or output_h == "":
        output_h, output_w = original_image.shape[:2]
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
            original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
            original_image = np.array(original_image)
            if input_mask is not None:
                input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
                input_mask = np.array(input_mask)
            if original_mask is not None:
                original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
                original_mask = np.array(original_mask)
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        else:
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
            pass
    else:
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
        original_image = np.array(original_image)
        if input_mask is not None:
            input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
            input_mask = np.array(input_mask)
        if original_mask is not None:
            original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
            original_mask = np.array(original_mask)

    if input_mask.max() == 0:
        original_mask = original_mask
    else:
        original_mask = input_mask

    if original_mask is None:
        raise gr.Error('Please generate mask first')

    if original_mask.ndim == 2:
        original_mask = original_mask[:,:,None]

    moved_mask = move_mask_func(original_mask, 'left', int(moving_pixels)).squeeze()
    mask_image = Image.fromarray(((moved_mask > 0).astype(np.uint8) * 255)).convert("RGB")

    masked_image = original_image * (1 - (moved_mask[:,:,None] > 0))
    masked_image = masked_image.astype(original_image.dtype)
    masked_image = Image.fromarray(masked_image)

    if moved_mask.max() <= 1:
        moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
    original_mask = moved_mask
    return [masked_image], [mask_image], original_mask.astype(np.uint8)

def move_mask_right(input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio_name):
    alpha_mask = input_image["layers"][0].split()[3]
    input_mask = np.asarray(alpha_mask)

    output_w, output_h = aspect_ratios[aspect_ratio_name]
    if output_w == "" or output_h == "":
        output_h, output_w = original_image.shape[:2]
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
            original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
            original_image = np.array(original_image)
            if input_mask is not None:
                input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
                input_mask = np.array(input_mask)
            if original_mask is not None:
                original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
                original_mask = np.array(original_mask)
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        else:
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
            pass
    else:
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
        original_image = np.array(original_image)
        if input_mask is not None:
            input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
            input_mask = np.array(input_mask)
        if original_mask is not None:
            original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
            original_mask = np.array(original_mask)

    if input_mask.max() == 0:
        original_mask = original_mask
    else:
        original_mask = input_mask

    if original_mask is None:
        raise gr.Error('Please generate mask first')

    if original_mask.ndim == 2:
        original_mask = original_mask[:,:,None]

    moved_mask = move_mask_func(original_mask, 'right', int(moving_pixels)).squeeze()

    mask_image = Image.fromarray(((moved_mask > 0).astype(np.uint8) * 255)).convert("RGB")

    masked_image = original_image * (1 - (moved_mask[:,:,None] > 0))
    masked_image = masked_image.astype(original_image.dtype)
    masked_image = Image.fromarray(masked_image)

    if moved_mask.max() <= 1:
        moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
    original_mask = moved_mask

    return [masked_image], [mask_image], original_mask.astype(np.uint8)

def move_mask_up(input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio_name):
    alpha_mask = input_image["layers"][0].split()[3]
    input_mask = np.asarray(alpha_mask)

    output_w, output_h = aspect_ratios[aspect_ratio_name]
    if output_w == "" or output_h == "":
        output_h, output_w = original_image.shape[:2]
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
            original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
            original_image = np.array(original_image)
            if input_mask is not None:
                input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
                input_mask = np.array(input_mask)
            if original_mask is not None:
                original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
                original_mask = np.array(original_mask)
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        else:
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
            pass
    else:
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
        original_image = np.array(original_image)
        if input_mask is not None:
            input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
            input_mask = np.array(input_mask)
        if original_mask is not None:
            original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
            original_mask = np.array(original_mask)

    if input_mask.max() == 0:
        original_mask = original_mask
    else:
        original_mask = input_mask

    if original_mask is None:
        raise gr.Error('Please generate mask first')

    if original_mask.ndim == 2:
        original_mask = original_mask[:,:,None]

    moved_mask = move_mask_func(original_mask, 'up', int(moving_pixels)).squeeze()
    mask_image = Image.fromarray(((moved_mask > 0).astype(np.uint8) * 255)).convert("RGB")

    masked_image = original_image * (1 - (moved_mask[:,:,None] > 0))
    masked_image = masked_image.astype(original_image.dtype)
    masked_image = Image.fromarray(masked_image)

    if moved_mask.max() <= 1:
        moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
    original_mask = moved_mask

    return [masked_image], [mask_image], original_mask.astype(np.uint8)

def move_mask_down(input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio_name):
    alpha_mask = input_image["layers"][0].split()[3]
    input_mask = np.asarray(alpha_mask)
    output_w, output_h = aspect_ratios[aspect_ratio_name]
    if output_w == "" or output_h == "":
        output_h, output_w = original_image.shape[:2]
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
            original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
            original_image = np.array(original_image)
            if input_mask is not None:
                input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
                input_mask = np.array(input_mask)
            if original_mask is not None:
                original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
                original_mask = np.array(original_mask)
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        else:
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
            pass
    else:
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
        original_image = np.array(original_image)
        if input_mask is not None:
            input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
            input_mask = np.array(input_mask)
        if original_mask is not None:
            original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
            original_mask = np.array(original_mask)

    if input_mask.max() == 0:
        original_mask = original_mask
    else:
        original_mask = input_mask

    if original_mask is None:
        raise gr.Error('Please generate mask first')

    if original_mask.ndim == 2:
        original_mask = original_mask[:,:,None]

    moved_mask = move_mask_func(original_mask, 'down', int(moving_pixels)).squeeze()
    mask_image = Image.fromarray(((moved_mask > 0).astype(np.uint8) * 255)).convert("RGB")

    masked_image = original_image * (1 - (moved_mask[:,:,None] > 0))
    masked_image = masked_image.astype(original_image.dtype)
    masked_image = Image.fromarray(masked_image)

    if moved_mask.max() <= 1:
        moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
    original_mask = moved_mask

    return [masked_image], [mask_image], original_mask.astype(np.uint8)

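# `invert_mask` flips the editing region so that everything outside the current mask
# becomes editable instead.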
def invert_mask(input_image, original_image, original_mask):
    alpha_mask = input_image["layers"][0].split()[3]
    input_mask = np.asarray(alpha_mask)
    if input_mask.max() == 0:
        original_mask = 1 - (original_mask > 0).astype(np.uint8)
    else:
        original_mask = 1 - (input_mask > 0).astype(np.uint8)

    if original_mask is None:
        raise gr.Error('Please generate mask first')

    original_mask = original_mask.squeeze()
    mask_image = Image.fromarray(original_mask * 255).convert("RGB")

    if original_mask.ndim == 2:
        original_mask = original_mask[:,:,None]

    if original_mask.max() <= 1:
        original_mask = (original_mask * 255).astype(np.uint8)

    masked_image = original_image * (1 - (original_mask > 0))
    masked_image = masked_image.astype(original_image.dtype)
    masked_image = Image.fromarray(masked_image)

    return [masked_image], [mask_image], original_mask, True

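# `init_img` seeds the UI either from a bundled example (reusing its cached mask, masked
# image and output previews) or from a fresh upload; uploads with an aspect ratio above
# 2.0 are rejected.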
def init_img(base, init_type, prompt, aspect_ratio, example_change_times):
    image_pil = base["background"].convert("RGB")
    original_image = np.array(image_pil)
    if max(original_image.shape[0], original_image.shape[1]) * 1.0 / min(original_image.shape[0], original_image.shape[1]) > 2.0:
        raise gr.Error('image aspect ratio cannot be larger than 2.0')
    if init_type in MASK_IMAGE_PATH.keys() and example_change_times < 2:
        mask_gallery = [Image.open(MASK_IMAGE_PATH[init_type]).convert("L")]
        masked_gallery = [Image.open(MASKED_IMAGE_PATH[init_type]).convert("RGB")]
        result_gallery = [Image.open(OUTPUT_IMAGE_PATH[init_type]).convert("RGB")]
        width, height = image_pil.size
        image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
        height_new, width_new = image_processor.get_default_height_width(image_pil, height, width)
        image_pil = image_pil.resize((width_new, height_new))
        mask_gallery[0] = mask_gallery[0].resize((width_new, height_new))
        masked_gallery[0] = masked_gallery[0].resize((width_new, height_new))
        result_gallery[0] = result_gallery[0].resize((width_new, height_new))
        original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None]  # h,w,1
        return base, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, "", "", "Custom resolution", False, False, example_change_times
    else:
        if aspect_ratio not in ASPECT_RATIO_LABELS:
            aspect_ratio = "Custom resolution"
        return base, original_image, None, "", None, None, None, "", "", aspect_ratio, True, False, 0

def reset_func(input_image, original_image, original_mask, prompt, target_prompt):
    input_image = None
    original_image = None
    original_mask = None
    prompt = ''
    mask_gallery = []
    masked_gallery = []
    result_gallery = []
    target_prompt = ''
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, True, False

def update_example(example_type, prompt, example_change_times):
    input_image = INPUT_IMAGE_PATH[example_type]
    image_pil = Image.open(input_image).convert("RGB")
    mask_gallery = [Image.open(MASK_IMAGE_PATH[example_type]).convert("L")]
    masked_gallery = [Image.open(MASKED_IMAGE_PATH[example_type]).convert("RGB")]
    result_gallery = [Image.open(OUTPUT_IMAGE_PATH[example_type]).convert("RGB")]
    width, height = image_pil.size
    image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
    height_new, width_new = image_processor.get_default_height_width(image_pil, height, width)
    image_pil = image_pil.resize((width_new, height_new))
    mask_gallery[0] = mask_gallery[0].resize((width_new, height_new))
    masked_gallery[0] = masked_gallery[0].resize((width_new, height_new))
    result_gallery[0] = result_gallery[0].resize((width_new, height_new))

    original_image = np.array(image_pil)
    original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None]  # h,w,1
    aspect_ratio = "Custom resolution"
    example_change_times += 1
    return input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, "", False, example_change_times

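# Gradio UI layout: the left column holds the image editor, instruction box, model choices
# and generation options; the right column shows mask previews, mask-adjustment tools and
# the final outputs.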
block = gr.Blocks(
|
1403 |
+
theme=gr.themes.Soft(
|
1404 |
+
radius_size=gr.themes.sizes.radius_none,
|
1405 |
+
text_size=gr.themes.sizes.text_md
|
1406 |
+
)
|
1407 |
+
)
|
1408 |
+
with block as demo:
|
1409 |
+
with gr.Row():
|
1410 |
+
with gr.Column():
|
1411 |
+
gr.HTML(head)
|
1412 |
+
|
1413 |
+
gr.Markdown(descriptions)
|
1414 |
+
|
1415 |
+
with gr.Accordion(label="🧭 Instructions:", open=True, elem_id="accordion"):
|
1416 |
+
with gr.Row(equal_height=True):
|
1417 |
+
gr.Markdown(instructions)
|
1418 |
+
|
1419 |
+
original_image = gr.State(value=None)
|
1420 |
+
original_mask = gr.State(value=None)
|
1421 |
+
category = gr.State(value=None)
|
1422 |
+
status = gr.State(value=None)
|
1423 |
+
invert_mask_state = gr.State(value=False)
|
1424 |
+
example_change_times = gr.State(value=0)
|
1425 |
+
|
1426 |
+
|
1427 |
+
with gr.Row():
|
1428 |
+
with gr.Column():
|
1429 |
+
with gr.Row():
|
1430 |
+
input_image = gr.ImageEditor(
|
1431 |
+
label="Input Image",
|
1432 |
+
type="pil",
|
1433 |
+
brush=gr.Brush(colors=["#FFFFFF"], default_size = 30, color_mode="fixed"),
|
1434 |
+
layers = False,
|
1435 |
+
interactive=True,
|
1436 |
+
height=1024,
|
1437 |
+
sources=["upload"],
|
1438 |
+
placeholder="Please click here or the icon below to upload the image.",
|
1439 |
+
)
|
1440 |
+
|
1441 |
+
prompt = gr.Textbox(label="⌨️ Instruction", placeholder="Please input your instruction.", value="",lines=1)
|
1442 |
+
run_button = gr.Button("💫 Run")
|
1443 |
+
|
1444 |
+
vlm_model_dropdown = gr.Dropdown(label="VLM model", choices=VLM_MODEL_NAMES, value=DEFAULT_VLM_MODEL_NAME, interactive=True)
|
1445 |
+
with gr.Group():
|
1446 |
+
with gr.Row():
|
1447 |
+
GPT4o_KEY = gr.Textbox(label="GPT4o API Key", placeholder="Please input your GPT4o API Key when use GPT4o VLM (highly recommended).", value="", lines=1)
|
1448 |
+
|
1449 |
+
GPT4o_KEY_submit = gr.Button("Submit and Verify")
|
1450 |
+
|
1451 |
+
|
1452 |
+
aspect_ratio = gr.Dropdown(label="Output aspect ratio", choices=ASPECT_RATIO_LABELS, value=DEFAULT_ASPECT_RATIO)
|
1453 |
+
resize_default = gr.Checkbox(label="Short edge resize to 640px", value=True)
|
1454 |
+
|
1455 |
+
with gr.Row():
|
1456 |
+
mask_button = gr.Button("Generate Mask")
|
1457 |
+
random_mask_button = gr.Button("Square/Circle Mask ")
|
1458 |
+
|
1459 |
+
|
1460 |
+
with gr.Row():
|
1461 |
+
generate_target_prompt_button = gr.Button("Generate Target Prompt")
|
1462 |
+
|
1463 |
+
target_prompt = gr.Text(
|
1464 |
+
label="Input Target Prompt",
|
1465 |
+
max_lines=5,
|
1466 |
+
placeholder="VLM-generated target prompt, you can first generate if and then modify it (optional)",
|
1467 |
+
value='',
|
1468 |
+
lines=2
|
1469 |
+
)
|
1470 |
+
|
1471 |
+
with gr.Accordion("Advanced Options", open=False, elem_id="accordion1"):
|
1472 |
+
base_model_dropdown = gr.Dropdown(label="Base model", choices=BASE_MODELS, value=DEFAULT_BASE_MODEL, interactive=True)
|
1473 |
+
negative_prompt = gr.Text(
|
1474 |
+
label="Negative Prompt",
|
1475 |
+
max_lines=5,
|
1476 |
+
placeholder="Please input your negative prompt",
|
1477 |
+
value='ugly, low quality',lines=1
|
1478 |
+
)
|
1479 |
+
|
1480 |
+
control_strength = gr.Slider(
|
1481 |
+
label="Control Strength: ", show_label=True, minimum=0, maximum=1.1, value=1, step=0.01
|
1482 |
+
)
|
1483 |
+
with gr.Group():
|
1484 |
+
seed = gr.Slider(
|
1485 |
+
label="Seed: ", minimum=0, maximum=2147483647, step=1, value=648464818
|
1486 |
+
)
|
1487 |
+
randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
|
1488 |
+
|
1489 |
+
blending = gr.Checkbox(label="Blending mode", value=True)
|
1490 |
+
|
1491 |
+
|
1492 |
+
num_samples = gr.Slider(
|
1493 |
+
label="Num samples", minimum=0, maximum=4, step=1, value=4
|
1494 |
+
)
|
1495 |
+
|
1496 |
+
with gr.Group():
|
1497 |
+
with gr.Row():
|
1498 |
+
guidance_scale = gr.Slider(
|
1499 |
+
label="Guidance scale",
|
1500 |
+
minimum=1,
|
1501 |
+
maximum=12,
|
1502 |
+
step=0.1,
|
1503 |
+
value=7.5,
|
1504 |
+
)
|
1505 |
+
num_inference_steps = gr.Slider(
|
1506 |
+
label="Number of inference steps",
|
1507 |
+
minimum=1,
|
1508 |
+
maximum=50,
|
1509 |
+
step=1,
|
1510 |
+
value=50,
|
1511 |
+
)
|
1512 |
+
|
1513 |
+
|
1514 |
+
with gr.Column():
|
1515 |
+
with gr.Row():
|
1516 |
+
with gr.Tab(elem_classes="feedback", label="Masked Image"):
|
1517 |
+
masked_gallery = gr.Gallery(label='Masked Image', show_label=True, elem_id="gallery", preview=True, height=360)
|
1518 |
+
with gr.Tab(elem_classes="feedback", label="Mask"):
|
1519 |
+
mask_gallery = gr.Gallery(label='Mask', show_label=True, elem_id="gallery", preview=True, height=360)
|
1520 |
+
|
1521 |
+
invert_mask_button = gr.Button("Invert Mask")
|
1522 |
+
dilation_size = gr.Slider(
|
1523 |
+
label="Dilation size: ", minimum=0, maximum=50, step=1, value=20
|
1524 |
+
)
|
1525 |
+
with gr.Row():
|
1526 |
+
dilation_mask_button = gr.Button("Dilate Generated Mask")
|
1527 |
+
erosion_mask_button = gr.Button("Erode Generated Mask")
|
1528 |
+
|
1529 |
+
moving_pixels = gr.Slider(
|
1530 |
+
label="Moving pixels:", show_label=True, minimum=0, maximum=50, value=4, step=1
|
1531 |
+
)
|
1532 |
+
with gr.Row():
|
1533 |
+
move_left_button = gr.Button("Move Left")
|
1534 |
+
move_right_button = gr.Button("Move Right")
|
1535 |
+
with gr.Row():
|
1536 |
+
move_up_button = gr.Button("Move Up")
|
1537 |
+
move_down_button = gr.Button("Move Down")
|
1538 |
+
|
1539 |
+
with gr.Tab(elem_classes="feedback", label="Output"):
|
1540 |
+
result_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", preview=True, height=400)
|
1541 |
+
|
1542 |
+
# target_prompt_output = gr.Text(label="Output Target Prompt", value="", lines=1, interactive=False)
|
1543 |
+
|
1544 |
+
reset_button = gr.Button("Reset")
|
1545 |
+
|
1546 |
+
init_type = gr.Textbox(label="Init Name", value="", visible=False)
|
1547 |
+
example_type = gr.Textbox(label="Example Name", value="", visible=False)
|
1548 |
+
|
1549 |
+
|
1550 |
+
|
1551 |
+
with gr.Row():
|
1552 |
+
example = gr.Examples(
|
1553 |
+
label="Quick Example",
|
1554 |
+
examples=EXAMPLES,
|
1555 |
+
inputs=[input_image, prompt, seed, init_type, example_type, blending, resize_default, vlm_model_dropdown],
|
1556 |
+
examples_per_page=10,
|
1557 |
+
cache_examples=False,
|
1558 |
+
)
|
1559 |
+
|
1560 |
+
|
1561 |
+
with gr.Accordion(label="🎬 Feature Details:", open=True, elem_id="accordion"):
|
1562 |
+
with gr.Row(equal_height=True):
|
1563 |
+
gr.Markdown(tips)
|
1564 |
+
|
1565 |
+
with gr.Row():
|
1566 |
+
gr.Markdown(citation)
|
1567 |
+
|
1568 |
+
## gr.Examples cannot be used to update the gr.Gallery, so we need the following two functions to update it.
|
1569 |
+
## We also need to resolve the conflict between the upload and example-change callbacks.
|
1570 |
+
input_image.upload(
|
1571 |
+
init_img,
|
1572 |
+
[input_image, init_type, prompt, aspect_ratio, example_change_times],
|
1573 |
+
[input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, init_type, aspect_ratio, resize_default, invert_mask_state, example_change_times]
|
1574 |
+
)
|
1575 |
+
example_type.change(fn=update_example, inputs=[example_type, prompt, example_change_times], outputs=[input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, target_prompt, invert_mask_state, example_change_times])
|
1576 |
+
|
1577 |
+
## vlm and base model dropdown
|
1578 |
+
vlm_model_dropdown.change(fn=update_vlm_model, inputs=[vlm_model_dropdown], outputs=[status])
|
1579 |
+
base_model_dropdown.change(fn=update_base_model, inputs=[base_model_dropdown], outputs=[status])
|
1580 |
+
|
1581 |
+
|
1582 |
+
GPT4o_KEY_submit.click(fn=submit_GPT4o_KEY, inputs=[GPT4o_KEY], outputs=[GPT4o_KEY, vlm_model_dropdown])
|
1583 |
+
invert_mask_button.click(fn=invert_mask, inputs=[input_image, original_image, original_mask], outputs=[masked_gallery, mask_gallery, original_mask, invert_mask_state])
|
1584 |
+
|
1585 |
+
|
1586 |
+
ips=[input_image,
|
1587 |
+
original_image,
|
1588 |
+
original_mask,
|
1589 |
+
prompt,
|
1590 |
+
negative_prompt,
|
1591 |
+
control_strength,
|
1592 |
+
seed,
|
1593 |
+
randomize_seed,
|
1594 |
+
guidance_scale,
|
1595 |
+
num_inference_steps,
|
1596 |
+
num_samples,
|
1597 |
+
blending,
|
1598 |
+
category,
|
1599 |
+
target_prompt,
|
1600 |
+
resize_default,
|
1601 |
+
aspect_ratio,
|
1602 |
+
invert_mask_state]
|
1603 |
+
|
1604 |
+
## run brushedit
|
1605 |
+
run_button.click(fn=process, inputs=ips, outputs=[result_gallery, mask_gallery, masked_gallery, prompt, target_prompt, invert_mask_state])
|
1606 |
+
|
1607 |
+
## mask func
|
1608 |
+
mask_button.click(fn=process_mask, inputs=[input_image, original_image, prompt, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask, category])
|
1609 |
+
random_mask_button.click(fn=process_random_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
|
1610 |
+
dilation_mask_button.click(fn=process_dilation_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[ masked_gallery, mask_gallery, original_mask])
|
1611 |
+
erosion_mask_button.click(fn=process_erosion_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[ masked_gallery, mask_gallery, original_mask])
|
1612 |
+
|
1613 |
+
## move mask func
|
1614 |
+
move_left_button.click(fn=move_mask_left, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
|
1615 |
+
move_right_button.click(fn=move_mask_right, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
|
1616 |
+
move_up_button.click(fn=move_mask_up, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
|
1617 |
+
move_down_button.click(fn=move_mask_down, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
|
1618 |
+
|
1619 |
+
## prompt func
|
1620 |
+
generate_target_prompt_button.click(fn=generate_target_prompt, inputs=[input_image, original_image, prompt], outputs=[target_prompt])
|
1621 |
+
|
1622 |
+
## reset func
|
1623 |
+
reset_button.click(fn=reset_func, inputs=[input_image, original_image, original_mask, prompt, target_prompt], outputs=[input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, resize_default, invert_mask_state])
|
1624 |
+
|
1625 |
+
## if you hit a localhost access error, try the following launch configuration
|
1626 |
+
demo.launch(server_name="0.0.0.0", server_port=12345, share=True)
|
1627 |
+
# demo.launch()
|
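The UI block above relies on Gradio's `gr.State` + `Button.click` wiring: state components persist values between events, and each click handler maps a list of input components to a list of output components. A minimal, hedged sketch of that same pattern follows; the component names and the stand-in callback are illustrative, not the app's actual implementations.

# Minimal sketch of the gr.State + Button.click wiring pattern used above (illustrative names).
import gradio as gr

def generate_mask(image, prompt, stored_mask):
    # Hypothetical stand-in for process_mask: echo the input image and cache it as the "mask" state.
    return ([image] if image is not None else []), image

with gr.Blocks() as demo:
    original_mask = gr.State(value=None)            # persists across events, like original_mask above
    input_image = gr.Image(type="pil", label="Input Image")
    prompt = gr.Textbox(label="Instruction")
    mask_button = gr.Button("Generate Mask")
    mask_gallery = gr.Gallery(label="Mask")

    mask_button.click(
        fn=generate_mask,
        inputs=[input_image, prompt, original_mask],
        outputs=[mask_gallery, original_mask],
    )

if __name__ == "__main__":
    demo.launch()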
brushedit_app_gradio_new.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
brushedit_app_new.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
brushedit_app_new_0404_cirr_blip1.py
ADDED
@@ -0,0 +1,2058 @@
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import os, random, sys
|
4 |
+
import numpy as np
|
5 |
+
import requests
|
6 |
+
import torch
|
7 |
+
|
8 |
+
from pathlib import Path
|
9 |
+
import pandas as pd
|
10 |
+
import concurrent.futures
|
11 |
+
import faiss
|
12 |
+
import gradio as gr
|
13 |
+
|
14 |
+
from pathlib import Path
|
15 |
+
import os
|
16 |
+
import json
|
17 |
+
|
18 |
+
from PIL import Image
|
19 |
+
|
20 |
+
import torch.nn.functional as F  # newly added line
|
21 |
+
from huggingface_hub import hf_hub_download, snapshot_download
|
22 |
+
from scipy.ndimage import binary_dilation, binary_erosion
|
23 |
+
from transformers import (LlavaNextProcessor, LlavaNextForConditionalGeneration,
|
24 |
+
Qwen2VLForConditionalGeneration, Qwen2VLProcessor)
|
25 |
+
|
26 |
+
from segment_anything import SamPredictor, build_sam, SamAutomaticMaskGenerator
|
27 |
+
from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler
|
28 |
+
from diffusers.image_processor import VaeImageProcessor
|
29 |
+
|
30 |
+
|
31 |
+
from app.src.vlm_pipeline import (
|
32 |
+
vlm_response_editing_type,
|
33 |
+
vlm_response_object_wait_for_edit,
|
34 |
+
vlm_response_mask,
|
35 |
+
vlm_response_prompt_after_apply_instruction
|
36 |
+
)
|
37 |
+
from app.src.brushedit_all_in_one_pipeline import BrushEdit_Pipeline
|
38 |
+
from app.utils.utils import load_grounding_dino_model
|
39 |
+
|
40 |
+
from app.src.vlm_template import vlms_template
|
41 |
+
from app.src.base_model_template import base_models_template
|
42 |
+
from app.src.aspect_ratio_template import aspect_ratios
|
43 |
+
|
44 |
+
from openai import OpenAI
|
45 |
+
base_openai_url = "https://api.deepseek.com/"
|
46 |
+
base_api_key = "sk-d145b963a92649a88843caeb741e8bbc"
|
47 |
+
|
48 |
+
|
49 |
+
from transformers import BlipProcessor, BlipForConditionalGeneration
|
50 |
+
from transformers import CLIPProcessor, CLIPModel
|
51 |
+
|
52 |
+
from app.deepseek.instructions import (
|
53 |
+
create_apply_editing_messages_deepseek,
|
54 |
+
create_decomposed_query_messages_deepseek
|
55 |
+
)
|
56 |
+
from clip_retrieval.clip_client import ClipClient
|
57 |
+
|
58 |
+
#### Description ####
|
59 |
+
logo = r"""
|
60 |
+
<center><img src='./assets/logo_brushedit.png' alt='BrushEdit logo' style="width:80px; margin-bottom:10px"></center>
|
61 |
+
"""
|
62 |
+
head = r"""
|
63 |
+
<div style="text-align: center;">
|
64 |
+
<h1> 基于扩散模型先验和大语言模型的零样本组合查询图像检索</h1>
|
65 |
+
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
|
66 |
+
<a href=''><img src='https://img.shields.io/badge/Project_Page-BrushEdit-green' alt='Project Page'></a>
|
67 |
+
<a href=''><img src='https://img.shields.io/badge/Paper-Arxiv-blue'></a>
|
68 |
+
<a href=''><img src='https://img.shields.io/badge/Code-Github-orange'></a>
|
69 |
+
|
70 |
+
</div>
|
71 |
+
</br>
|
72 |
+
</div>
|
73 |
+
"""
|
74 |
+
descriptions = r"""
|
75 |
+
Demo for ZS-CIR"""
|
76 |
+
|
77 |
+
instructions = r"""
|
78 |
+
Demo for ZS-CIR"""
|
79 |
+
|
80 |
+
tips = r"""
|
81 |
+
Demo for ZS-CIR
|
82 |
+
|
83 |
+
"""
|
84 |
+
|
85 |
+
citation = r"""
|
86 |
+
Demo for ZS-CIR"""
|
87 |
+
|
88 |
+
# - - - - - examples - - - - - #
|
89 |
+
EXAMPLES = [
|
90 |
+
|
91 |
+
[
|
92 |
+
Image.open("./assets/frog/frog.jpeg").convert("RGBA"),
|
93 |
+
"add a magic hat on frog head.",
|
94 |
+
642087011,
|
95 |
+
"frog",
|
96 |
+
"frog",
|
97 |
+
True,
|
98 |
+
False,
|
99 |
+
"GPT4-o (Highly Recommended)"
|
100 |
+
],
|
101 |
+
[
|
102 |
+
Image.open("./assets/chinese_girl/chinese_girl.png").convert("RGBA"),
|
103 |
+
"replace the background to ancient China.",
|
104 |
+
648464818,
|
105 |
+
"chinese_girl",
|
106 |
+
"chinese_girl",
|
107 |
+
True,
|
108 |
+
False,
|
109 |
+
"GPT4-o (Highly Recommended)"
|
110 |
+
],
|
111 |
+
[
|
112 |
+
Image.open("./assets/angel_christmas/angel_christmas.png").convert("RGBA"),
|
113 |
+
"remove the deer.",
|
114 |
+
648464818,
|
115 |
+
"angel_christmas",
|
116 |
+
"angel_christmas",
|
117 |
+
False,
|
118 |
+
False,
|
119 |
+
"GPT4-o (Highly Recommended)"
|
120 |
+
],
|
121 |
+
[
|
122 |
+
Image.open("./assets/sunflower_girl/sunflower_girl.png").convert("RGBA"),
|
123 |
+
"add a wreath on head.",
|
124 |
+
648464818,
|
125 |
+
"sunflower_girl",
|
126 |
+
"sunflower_girl",
|
127 |
+
True,
|
128 |
+
False,
|
129 |
+
"GPT4-o (Highly Recommended)"
|
130 |
+
],
|
131 |
+
[
|
132 |
+
Image.open("./assets/girl_on_sun/girl_on_sun.png").convert("RGBA"),
|
133 |
+
"add a butterfly fairy.",
|
134 |
+
648464818,
|
135 |
+
"girl_on_sun",
|
136 |
+
"girl_on_sun",
|
137 |
+
True,
|
138 |
+
False,
|
139 |
+
"GPT4-o (Highly Recommended)"
|
140 |
+
],
|
141 |
+
[
|
142 |
+
Image.open("./assets/spider_man_rm/spider_man.png").convert("RGBA"),
|
143 |
+
"remove the christmas hat.",
|
144 |
+
642087011,
|
145 |
+
"spider_man_rm",
|
146 |
+
"spider_man_rm",
|
147 |
+
False,
|
148 |
+
False,
|
149 |
+
"GPT4-o (Highly Recommended)"
|
150 |
+
],
|
151 |
+
[
|
152 |
+
Image.open("./assets/anime_flower/anime_flower.png").convert("RGBA"),
|
153 |
+
"remove the flower.",
|
154 |
+
642087011,
|
155 |
+
"anime_flower",
|
156 |
+
"anime_flower",
|
157 |
+
False,
|
158 |
+
False,
|
159 |
+
"GPT4-o (Highly Recommended)"
|
160 |
+
],
|
161 |
+
[
|
162 |
+
Image.open("./assets/chenduling/chengduling.jpg").convert("RGBA"),
|
163 |
+
"replace the clothes to a delicated floral skirt.",
|
164 |
+
648464818,
|
165 |
+
"chenduling",
|
166 |
+
"chenduling",
|
167 |
+
True,
|
168 |
+
False,
|
169 |
+
"GPT4-o (Highly Recommended)"
|
170 |
+
],
|
171 |
+
[
|
172 |
+
Image.open("./assets/hedgehog_rp_bg/hedgehog.png").convert("RGBA"),
|
173 |
+
"make the hedgehog in Italy.",
|
174 |
+
648464818,
|
175 |
+
"hedgehog_rp_bg",
|
176 |
+
"hedgehog_rp_bg",
|
177 |
+
True,
|
178 |
+
False,
|
179 |
+
"GPT4-o (Highly Recommended)"
|
180 |
+
],
|
181 |
+
|
182 |
+
]
|
183 |
+
|
184 |
+
INPUT_IMAGE_PATH = {
|
185 |
+
"frog": "./assets/frog/frog.jpeg",
|
186 |
+
"chinese_girl": "./assets/chinese_girl/chinese_girl.png",
|
187 |
+
"angel_christmas": "./assets/angel_christmas/angel_christmas.png",
|
188 |
+
"sunflower_girl": "./assets/sunflower_girl/sunflower_girl.png",
|
189 |
+
"girl_on_sun": "./assets/girl_on_sun/girl_on_sun.png",
|
190 |
+
"spider_man_rm": "./assets/spider_man_rm/spider_man.png",
|
191 |
+
"anime_flower": "./assets/anime_flower/anime_flower.png",
|
192 |
+
"chenduling": "./assets/chenduling/chengduling.jpg",
|
193 |
+
"hedgehog_rp_bg": "./assets/hedgehog_rp_bg/hedgehog.png",
|
194 |
+
}
|
195 |
+
MASK_IMAGE_PATH = {
|
196 |
+
"frog": "./assets/frog/mask_f7b350de-6f2c-49e3-b535-995c486d78e7.png",
|
197 |
+
"chinese_girl": "./assets/chinese_girl/mask_54759648-0989-48e0-bc82-f20e28b5ec29.png",
|
198 |
+
"angel_christmas": "./assets/angel_christmas/mask_f15d9b45-c978-4e3d-9f5f-251e308560c3.png",
|
199 |
+
"sunflower_girl": "./assets/sunflower_girl/mask_99cc50b4-7dc4-4de5-8748-ec10772f0317.png",
|
200 |
+
"girl_on_sun": "./assets/girl_on_sun/mask_264eac8b-8b65-479c-9755-020a60880c37.png",
|
201 |
+
"spider_man_rm": "./assets/spider_man_rm/mask_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png",
|
202 |
+
"anime_flower": "./assets/anime_flower/mask_37553172-9b38-4727-bf2e-37d7e2b93461.png",
|
203 |
+
"chenduling": "./assets/chenduling/mask_68e3ff6f-da07-4b37-91df-13d6eed7b997.png",
|
204 |
+
"hedgehog_rp_bg": "./assets/hedgehog_rp_bg/mask_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png",
|
205 |
+
}
|
206 |
+
MASKED_IMAGE_PATH = {
|
207 |
+
"frog": "./assets/frog/masked_image_f7b350de-6f2c-49e3-b535-995c486d78e7.png",
|
208 |
+
"chinese_girl": "./assets/chinese_girl/masked_image_54759648-0989-48e0-bc82-f20e28b5ec29.png",
|
209 |
+
"angel_christmas": "./assets/angel_christmas/masked_image_f15d9b45-c978-4e3d-9f5f-251e308560c3.png",
|
210 |
+
"sunflower_girl": "./assets/sunflower_girl/masked_image_99cc50b4-7dc4-4de5-8748-ec10772f0317.png",
|
211 |
+
"girl_on_sun": "./assets/girl_on_sun/masked_image_264eac8b-8b65-479c-9755-020a60880c37.png",
|
212 |
+
"spider_man_rm": "./assets/spider_man_rm/masked_image_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png",
|
213 |
+
"anime_flower": "./assets/anime_flower/masked_image_37553172-9b38-4727-bf2e-37d7e2b93461.png",
|
214 |
+
"chenduling": "./assets/chenduling/masked_image_68e3ff6f-da07-4b37-91df-13d6eed7b997.png",
|
215 |
+
"hedgehog_rp_bg": "./assets/hedgehog_rp_bg/masked_image_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png",
|
216 |
+
}
|
217 |
+
OUTPUT_IMAGE_PATH = {
|
218 |
+
"frog": "./assets/frog/image_edit_f7b350de-6f2c-49e3-b535-995c486d78e7_1.png",
|
219 |
+
"chinese_girl": "./assets/chinese_girl/image_edit_54759648-0989-48e0-bc82-f20e28b5ec29_1.png",
|
220 |
+
"angel_christmas": "./assets/angel_christmas/image_edit_f15d9b45-c978-4e3d-9f5f-251e308560c3_0.png",
|
221 |
+
"sunflower_girl": "./assets/sunflower_girl/image_edit_99cc50b4-7dc4-4de5-8748-ec10772f0317_3.png",
|
222 |
+
"girl_on_sun": "./assets/girl_on_sun/image_edit_264eac8b-8b65-479c-9755-020a60880c37_0.png",
|
223 |
+
"spider_man_rm": "./assets/spider_man_rm/image_edit_a5d410e6-8e8d-432f-8144-defbc3e1eae9_0.png",
|
224 |
+
"anime_flower": "./assets/anime_flower/image_edit_37553172-9b38-4727-bf2e-37d7e2b93461_2.png",
|
225 |
+
"chenduling": "./assets/chenduling/image_edit_68e3ff6f-da07-4b37-91df-13d6eed7b997_0.png",
|
226 |
+
"hedgehog_rp_bg": "./assets/hedgehog_rp_bg/image_edit_db7f8bf8-8349-46d3-b14e-43d67fbe25d3_3.png",
|
227 |
+
}
|
228 |
+
|
229 |
+
# os.environ['GRADIO_TEMP_DIR'] = 'gradio_temp_dir'
|
230 |
+
# os.makedirs('gradio_temp_dir', exist_ok=True)
|
231 |
+
|
232 |
+
VLM_MODEL_NAMES = list(vlms_template.keys())
|
233 |
+
DEFAULT_VLM_MODEL_NAME = "Qwen2-VL-7B-Instruct (Default)"
|
234 |
+
|
235 |
+
|
236 |
+
BASE_MODELS = list(base_models_template.keys())
|
237 |
+
DEFAULT_BASE_MODEL = "realisticVision (Default)"
|
238 |
+
|
239 |
+
ASPECT_RATIO_LABELS = list(aspect_ratios)
|
240 |
+
DEFAULT_ASPECT_RATIO = ASPECT_RATIO_LABELS[0]
|
241 |
+
|
242 |
+
|
243 |
+
## init device
|
244 |
+
try:
|
245 |
+
if torch.cuda.is_available():
|
246 |
+
device = "cuda"
|
247 |
+
elif sys.platform == "darwin" and torch.backends.mps.is_available():
|
248 |
+
device = "mps"
|
249 |
+
else:
|
250 |
+
device = "cpu"
|
251 |
+
except Exception:
|
252 |
+
device = "cpu"
|
253 |
+
|
254 |
+
# ## init torch dtype
|
255 |
+
# if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
|
256 |
+
# torch_dtype = torch.bfloat16
|
257 |
+
# else:
|
258 |
+
# torch_dtype = torch.float16
|
259 |
+
|
260 |
+
# if device == "mps":
|
261 |
+
# torch_dtype = torch.float16
|
262 |
+
|
263 |
+
torch_dtype = torch.float16
|
264 |
+
|
265 |
+
|
266 |
+
|
267 |
+
# download hf models
|
268 |
+
BrushEdit_path = "models/"
|
269 |
+
if not os.path.exists(BrushEdit_path):
|
270 |
+
BrushEdit_path = snapshot_download(
|
271 |
+
repo_id="TencentARC/BrushEdit",
|
272 |
+
local_dir=BrushEdit_path,
|
273 |
+
token=os.getenv("HF_TOKEN"),
|
274 |
+
)
|
275 |
+
|
276 |
+
## init default VLM
|
277 |
+
vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[DEFAULT_VLM_MODEL_NAME]
|
278 |
+
if vlm_processor != "" and vlm_model != "":
|
279 |
+
vlm_model.to(device)
|
280 |
+
else:
|
281 |
+
raise gr.Error("Please Download default VLM model "+ DEFAULT_VLM_MODEL_NAME +" first.")
|
282 |
+
|
283 |
+
## init default LLM
|
284 |
+
llm_model = OpenAI(api_key=base_api_key, base_url=base_openai_url)
|
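As a hedged illustration of how the OpenAI-compatible `llm_model` client created above is typically queried: the model name "deepseek-chat" and the prompts below are assumptions for the sketch, not values read from this repository.

# Illustrative chat call through the llm_model client defined above.
# "deepseek-chat" is an assumed model name for the DeepSeek endpoint, not taken from this file.
response = llm_model.chat.completions.create(
    model="deepseek-chat",
    messages=[
        {"role": "system", "content": "You rewrite image-editing instructions into target captions."},
        {"role": "user", "content": "Original: a frog on a log. Edit: add a magic hat."},
    ],
    temperature=0.2,
)
print(response.choices[0].message.content)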
285 |
+
|
286 |
+
## init base model
|
287 |
+
base_model_path = os.path.join(BrushEdit_path, "base_model/realisticVisionV60B1_v51VAE")
|
288 |
+
brushnet_path = os.path.join(BrushEdit_path, "brushnetX")
|
289 |
+
sam_path = os.path.join(BrushEdit_path, "sam/sam_vit_h_4b8939.pth")
|
290 |
+
groundingdino_path = os.path.join(BrushEdit_path, "grounding_dino/groundingdino_swint_ogc.pth")
|
291 |
+
|
292 |
+
|
293 |
+
# input brushnetX ckpt path
|
294 |
+
brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch_dtype)
|
295 |
+
pipe = StableDiffusionBrushNetPipeline.from_pretrained(
|
296 |
+
base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
|
297 |
+
)
|
298 |
+
# speed up diffusion process with faster scheduler and memory optimization
|
299 |
+
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
|
300 |
+
# uncomment the following line only if xformers is installed (not needed with Torch 2.0).
|
301 |
+
# pipe.enable_xformers_memory_efficient_attention()
|
302 |
+
pipe.enable_model_cpu_offload()
|
303 |
+
|
304 |
+
|
305 |
+
## init SAM
|
306 |
+
sam = build_sam(checkpoint=sam_path)
|
307 |
+
sam.to(device=device)
|
308 |
+
sam_predictor = SamPredictor(sam)
|
309 |
+
sam_automask_generator = SamAutomaticMaskGenerator(sam)
|
310 |
+
|
311 |
+
## init groundingdino_model
|
312 |
+
config_file = 'app/utils/GroundingDINO_SwinT_OGC.py'
|
313 |
+
groundingdino_model = load_grounding_dino_model(config_file, groundingdino_path, device=device)
|
314 |
+
|
315 |
+
## Ordinary function
|
316 |
+
def crop_and_resize(image: Image.Image,
|
317 |
+
target_width: int,
|
318 |
+
target_height: int) -> Image.Image:
|
319 |
+
"""
|
320 |
+
Crops and resizes an image while preserving the aspect ratio.
|
321 |
+
|
322 |
+
Args:
|
323 |
+
image (Image.Image): Input PIL image to be cropped and resized.
|
324 |
+
target_width (int): Target width of the output image.
|
325 |
+
target_height (int): Target height of the output image.
|
326 |
+
|
327 |
+
Returns:
|
328 |
+
Image.Image: Cropped and resized image.
|
329 |
+
"""
|
330 |
+
# Original dimensions
|
331 |
+
original_width, original_height = image.size
|
332 |
+
original_aspect = original_width / original_height
|
333 |
+
target_aspect = target_width / target_height
|
334 |
+
|
335 |
+
# Calculate crop box to maintain aspect ratio
|
336 |
+
if original_aspect > target_aspect:
|
337 |
+
# Crop horizontally
|
338 |
+
new_width = int(original_height * target_aspect)
|
339 |
+
new_height = original_height
|
340 |
+
left = (original_width - new_width) / 2
|
341 |
+
top = 0
|
342 |
+
right = left + new_width
|
343 |
+
bottom = original_height
|
344 |
+
else:
|
345 |
+
# Crop vertically
|
346 |
+
new_width = original_width
|
347 |
+
new_height = int(original_width / target_aspect)
|
348 |
+
left = 0
|
349 |
+
top = (original_height - new_height) / 2
|
350 |
+
right = original_width
|
351 |
+
bottom = top + new_height
|
352 |
+
|
353 |
+
# Crop and resize
|
354 |
+
cropped_image = image.crop((left, top, right, bottom))
|
355 |
+
resized_image = cropped_image.resize((target_width, target_height), Image.NEAREST)
|
356 |
+
return resized_image
|
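A quick, hedged usage sketch of crop_and_resize: the input file is one of the repo's example assets and the target size is illustrative; the longer dimension is center-cropped before resizing.

# Illustrative usage: center-crop to a 1:1 window, then resize to 640x640.
from PIL import Image

img = Image.open("./assets/frog/frog.jpeg").convert("RGB")
square = crop_and_resize(img, target_width=640, target_height=640)
print(square.size)  # (640, 640); the wider dimension was center-cropped first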
357 |
+
|
358 |
+
|
359 |
+
## Ordinary function
|
360 |
+
def resize(image: Image.Image,
|
361 |
+
target_width: int,
|
362 |
+
target_height: int) -> Image.Image:
|
363 |
+
"""
|
364 |
+
Resizes an image to the target size; the aspect ratio is not preserved.
|
365 |
+
|
366 |
+
Args:
|
367 |
+
image (Image.Image): Input PIL image to be resized.
|
368 |
+
target_width (int): Target width of the output image.
|
369 |
+
target_height (int): Target height of the output image.
|
370 |
+
|
371 |
+
Returns:
|
372 |
+
Image.Image: Resized image.
|
373 |
+
"""
|
374 |
+
# Resize directly to the target size
|
375 |
+
resized_image = image.resize((target_width, target_height), Image.NEAREST)
|
376 |
+
return resized_image
|
377 |
+
|
378 |
+
|
379 |
+
def move_mask_func(mask, direction, units):
|
380 |
+
binary_mask = mask.squeeze()>0
|
381 |
+
rows, cols = binary_mask.shape
|
382 |
+
moved_mask = np.zeros_like(binary_mask, dtype=bool)
|
383 |
+
|
384 |
+
if direction == 'down':
|
385 |
+
# move down
|
386 |
+
moved_mask[max(0, units):, :] = binary_mask[:rows - units, :]
|
387 |
+
|
388 |
+
elif direction == 'up':
|
389 |
+
# move up
|
390 |
+
moved_mask[:rows - units, :] = binary_mask[units:, :]
|
391 |
+
|
392 |
+
elif direction == 'right':
|
393 |
+
# move right
|
394 |
+
moved_mask[:, max(0, units):] = binary_mask[:, :cols - units]
|
395 |
+
|
396 |
+
elif direction == 'left':
|
397 |
+
# move left
|
398 |
+
moved_mask[:, :cols - units] = binary_mask[:, units:]
|
399 |
+
|
400 |
+
return moved_mask
|
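A hedged sanity-check sketch for move_mask_func, showing a single mask pixel shifted two columns to the right; the array size is illustrative.

# Illustrative check: shift a tiny binary mask 2 pixels to the right.
import numpy as np

mask = np.zeros((5, 5), dtype=np.uint8)
mask[2, 1] = 255                          # single "on" pixel at row 2, column 1
moved = move_mask_func(mask, "right", 2)  # boolean array; the pixel lands at column 3
print(np.argwhere(moved))                 # -> [[2 3]]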
401 |
+
|
402 |
+
|
403 |
+
def random_mask_func(mask, dilation_type='square', dilation_size=20):
|
404 |
+
# Apply the selected dilation/erosion type to the binary mask
|
405 |
+
binary_mask = mask.squeeze()>0
|
406 |
+
|
407 |
+
if dilation_type == 'square_dilation':
|
408 |
+
structure = np.ones((dilation_size, dilation_size), dtype=bool)
|
409 |
+
dilated_mask = binary_dilation(binary_mask, structure=structure)
|
410 |
+
elif dilation_type == 'square_erosion':
|
411 |
+
structure = np.ones((dilation_size, dilation_size), dtype=bool)
|
412 |
+
dilated_mask = binary_erosion(binary_mask, structure=structure)
|
413 |
+
elif dilation_type == 'bounding_box':
|
414 |
+
# find the bounding box of the mask region
|
415 |
+
rows, cols = np.where(binary_mask)
|
416 |
+
if len(rows) == 0 or len(cols) == 0:
|
417 |
+
return mask # return original mask if no valid points
|
418 |
+
|
419 |
+
min_row = np.min(rows)
|
420 |
+
max_row = np.max(rows)
|
421 |
+
min_col = np.min(cols)
|
422 |
+
max_col = np.max(cols)
|
423 |
+
|
424 |
+
# create a bounding box
|
425 |
+
dilated_mask = np.zeros_like(binary_mask, dtype=bool)
|
426 |
+
dilated_mask[min_row:max_row + 1, min_col:max_col + 1] = True
|
427 |
+
|
428 |
+
elif dilation_type == 'bounding_ellipse':
|
429 |
+
# find the bounding box of the mask region
|
430 |
+
rows, cols = np.where(binary_mask)
|
431 |
+
if len(rows) == 0 or len(cols) == 0:
|
432 |
+
return mask # return original mask if no valid points
|
433 |
+
|
434 |
+
min_row = np.min(rows)
|
435 |
+
max_row = np.max(rows)
|
436 |
+
min_col = np.min(cols)
|
437 |
+
max_col = np.max(cols)
|
438 |
+
|
439 |
+
# calculate the center and axis length of the ellipse
|
440 |
+
center = ((min_col + max_col) // 2, (min_row + max_row) // 2)
|
441 |
+
a = (max_col - min_col) // 2 # half long axis
|
442 |
+
b = (max_row - min_row) // 2 # half short axis
|
443 |
+
|
444 |
+
# create a bounding ellipse
|
445 |
+
y, x = np.ogrid[:mask.shape[0], :mask.shape[1]]
|
446 |
+
ellipse_mask = ((x - center[0])**2 / a**2 + (y - center[1])**2 / b**2) <= 1
|
447 |
+
dilated_mask = np.zeros_like(binary_mask, dtype=bool)
|
448 |
+
dilated_mask[ellipse_mask] = True
|
449 |
+
else:
|
450 |
+
raise ValueError("dilation_type must be 'square_dilation', 'square_erosion', 'bounding_box' or 'bounding_ellipse'")
|
451 |
+
|
452 |
+
# convert the boolean mask to a uint8 image (0/255)
|
453 |
+
dilated_mask = np.uint8(dilated_mask[:,:,np.newaxis]) * 255
|
454 |
+
return dilated_mask
|
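A hedged usage sketch for random_mask_func, exercising the square-dilation branch; the mask shape and structuring-element size are illustrative.

# Illustrative use: dilate a 4x4 mask with a 5x5 square structuring element.
import numpy as np

mask = np.zeros((20, 20, 1), dtype=np.uint8)
mask[8:12, 8:12, 0] = 255                                    # 4x4 block of "on" pixels
dilated = random_mask_func(mask, dilation_type="square_dilation", dilation_size=5)
print(dilated.shape, dilated.max())                          # (20, 20, 1) 255
print(int((dilated > 0).sum()), ">", int((mask > 0).sum()))  # the dilated mask covers more pixels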
455 |
+
|
456 |
+
|
457 |
+
## Gradio component function
|
458 |
+
def update_vlm_model(vlm_name):
|
459 |
+
global vlm_model, vlm_processor
|
460 |
+
if vlm_model is not None:
|
461 |
+
del vlm_model
|
462 |
+
torch.cuda.empty_cache()
|
463 |
+
|
464 |
+
vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[vlm_name]
|
465 |
+
|
466 |
+
## We recommend using preloaded models; otherwise it will take a long time to download the model. You can edit the code via vlm_template.py.
|
467 |
+
if vlm_type == "llava-next":
|
468 |
+
if vlm_processor != "" and vlm_model != "":
|
469 |
+
vlm_model.to(device)
|
470 |
+
return vlm_model_dropdown
|
471 |
+
else:
|
472 |
+
if os.path.exists(vlm_local_path):
|
473 |
+
vlm_processor = LlavaNextProcessor.from_pretrained(vlm_local_path)
|
474 |
+
vlm_model = LlavaNextForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto")
|
475 |
+
else:
|
476 |
+
if vlm_name == "llava-v1.6-mistral-7b-hf (Preload)":
|
477 |
+
vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
|
478 |
+
vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", torch_dtype="auto", device_map="auto")
|
479 |
+
elif vlm_name == "llama3-llava-next-8b-hf (Preload)":
|
480 |
+
vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llama3-llava-next-8b-hf")
|
481 |
+
vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llama3-llava-next-8b-hf", torch_dtype="auto", device_map="auto")
|
482 |
+
elif vlm_name == "llava-v1.6-vicuna-13b-hf (Preload)":
|
483 |
+
vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf")
|
484 |
+
vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf", torch_dtype="auto", device_map="auto")
|
485 |
+
elif vlm_name == "llava-v1.6-34b-hf (Preload)":
|
486 |
+
vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-34b-hf")
|
487 |
+
vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-34b-hf", torch_dtype="auto", device_map="auto")
|
488 |
+
elif vlm_name == "llava-next-72b-hf (Preload)":
|
489 |
+
vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-next-72b-hf")
|
490 |
+
vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-next-72b-hf", torch_dtype="auto", device_map="auto")
|
491 |
+
elif vlm_type == "qwen2-vl":
|
492 |
+
if vlm_processor != "" and vlm_model != "":
|
493 |
+
vlm_model.to(device)
|
494 |
+
return vlm_model_dropdown
|
495 |
+
else:
|
496 |
+
if os.path.exists(vlm_local_path):
|
497 |
+
vlm_processor = Qwen2VLProcessor.from_pretrained(vlm_local_path)
|
498 |
+
vlm_model = Qwen2VLForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto")
|
499 |
+
else:
|
500 |
+
if vlm_name == "qwen2-vl-2b-instruct (Preload)":
|
501 |
+
vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
|
502 |
+
vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", torch_dtype="auto", device_map="auto")
|
503 |
+
elif vlm_name == "qwen2-vl-7b-instruct (Preload)":
|
504 |
+
vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
|
505 |
+
vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto")
|
506 |
+
elif vlm_name == "qwen2-vl-72b-instruct (Preload)":
|
507 |
+
vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-72B-Instruct")
|
508 |
+
vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-72B-Instruct", torch_dtype="auto", device_map="auto")
|
509 |
+
elif vlm_type == "openai":
|
510 |
+
pass
|
511 |
+
return "success"
|
512 |
+
|
513 |
+
|
514 |
+
def update_base_model(base_model_name):
|
515 |
+
global pipe
|
516 |
+
## We recommend using preloaded models; otherwise it will take a long time to download the model. You can edit the code via base_model_template.py.
|
517 |
+
if pipe is not None:
|
518 |
+
del pipe
|
519 |
+
torch.cuda.empty_cache()
|
520 |
+
base_model_path, pipe = base_models_template[base_model_name]
|
521 |
+
if pipe != "":
|
522 |
+
pipe.to(device)
|
523 |
+
else:
|
524 |
+
if os.path.exists(base_model_path):
|
525 |
+
pipe = StableDiffusionBrushNetPipeline.from_pretrained(
|
526 |
+
base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
|
527 |
+
)
|
528 |
+
# pipe.enable_xformers_memory_efficient_attention()
|
529 |
+
pipe.enable_model_cpu_offload()
|
530 |
+
else:
|
531 |
+
raise gr.Error(f"The base model {base_model_name} does not exist")
|
532 |
+
return "success"
|
533 |
+
|
534 |
+
|
535 |
+
def process_random_mask(input_image,
|
536 |
+
original_image,
|
537 |
+
original_mask,
|
538 |
+
resize_default,
|
539 |
+
aspect_ratio_name,
|
540 |
+
):
|
541 |
+
|
542 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
543 |
+
input_mask = np.asarray(alpha_mask)
|
544 |
+
|
545 |
+
output_w, output_h = aspect_ratios[aspect_ratio_name]
|
546 |
+
if output_w == "" or output_h == "":
|
547 |
+
output_h, output_w = original_image.shape[:2]
|
548 |
+
if resize_default:
|
549 |
+
short_side = min(output_w, output_h)
|
550 |
+
scale_ratio = 640 / short_side
|
551 |
+
output_w = int(output_w * scale_ratio)
|
552 |
+
output_h = int(output_h * scale_ratio)
|
553 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
554 |
+
original_image = np.array(original_image)
|
555 |
+
if input_mask is not None:
|
556 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
557 |
+
input_mask = np.array(input_mask)
|
558 |
+
if original_mask is not None:
|
559 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
560 |
+
original_mask = np.array(original_mask)
|
561 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
562 |
+
else:
|
563 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
564 |
+
pass
|
565 |
+
else:
|
566 |
+
if resize_default:
|
567 |
+
short_side = min(output_w, output_h)
|
568 |
+
scale_ratio = 640 / short_side
|
569 |
+
output_w = int(output_w * scale_ratio)
|
570 |
+
output_h = int(output_h * scale_ratio)
|
571 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
572 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
573 |
+
original_image = np.array(original_image)
|
574 |
+
if input_mask is not None:
|
575 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
576 |
+
input_mask = np.array(input_mask)
|
577 |
+
if original_mask is not None:
|
578 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
579 |
+
original_mask = np.array(original_mask)
|
580 |
+
|
581 |
+
|
582 |
+
if input_mask.max() == 0:
|
583 |
+
original_mask = original_mask
|
584 |
+
else:
|
585 |
+
original_mask = input_mask
|
586 |
+
|
587 |
+
if original_mask is None:
|
588 |
+
raise gr.Error('Please generate mask first')
|
589 |
+
|
590 |
+
if original_mask.ndim == 2:
|
591 |
+
original_mask = original_mask[:,:,None]
|
592 |
+
|
593 |
+
dilation_type = np.random.choice(['bounding_box', 'bounding_ellipse'])
|
594 |
+
random_mask = random_mask_func(original_mask, dilation_type).squeeze()
|
595 |
+
|
596 |
+
mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
|
597 |
+
|
598 |
+
masked_image = original_image * (1 - (random_mask[:,:,None]>0))
|
599 |
+
masked_image = masked_image.astype(original_image.dtype)
|
600 |
+
masked_image = Image.fromarray(masked_image)
|
601 |
+
|
602 |
+
|
603 |
+
return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
|
604 |
+
|
605 |
+
|
606 |
+
def process_dilation_mask(input_image,
|
607 |
+
original_image,
|
608 |
+
original_mask,
|
609 |
+
resize_default,
|
610 |
+
aspect_ratio_name,
|
611 |
+
dilation_size=20):
|
612 |
+
|
613 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
614 |
+
input_mask = np.asarray(alpha_mask)
|
615 |
+
|
616 |
+
output_w, output_h = aspect_ratios[aspect_ratio_name]
|
617 |
+
if output_w == "" or output_h == "":
|
618 |
+
output_h, output_w = original_image.shape[:2]
|
619 |
+
if resize_default:
|
620 |
+
short_side = min(output_w, output_h)
|
621 |
+
scale_ratio = 640 / short_side
|
622 |
+
output_w = int(output_w * scale_ratio)
|
623 |
+
output_h = int(output_h * scale_ratio)
|
624 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
625 |
+
original_image = np.array(original_image)
|
626 |
+
if input_mask is not None:
|
627 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
628 |
+
input_mask = np.array(input_mask)
|
629 |
+
if original_mask is not None:
|
630 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
631 |
+
original_mask = np.array(original_mask)
|
632 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
633 |
+
else:
|
634 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
635 |
+
pass
|
636 |
+
else:
|
637 |
+
if resize_default:
|
638 |
+
short_side = min(output_w, output_h)
|
639 |
+
scale_ratio = 640 / short_side
|
640 |
+
output_w = int(output_w * scale_ratio)
|
641 |
+
output_h = int(output_h * scale_ratio)
|
642 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
643 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
644 |
+
original_image = np.array(original_image)
|
645 |
+
if input_mask is not None:
|
646 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
647 |
+
input_mask = np.array(input_mask)
|
648 |
+
if original_mask is not None:
|
649 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
650 |
+
original_mask = np.array(original_mask)
|
651 |
+
|
652 |
+
if input_mask.max() == 0:
|
653 |
+
original_mask = original_mask
|
654 |
+
else:
|
655 |
+
original_mask = input_mask
|
656 |
+
|
657 |
+
if original_mask is None:
|
658 |
+
raise gr.Error('Please generate mask first')
|
659 |
+
|
660 |
+
if original_mask.ndim == 2:
|
661 |
+
original_mask = original_mask[:,:,None]
|
662 |
+
|
663 |
+
dilation_type = np.random.choice(['square_dilation'])
|
664 |
+
random_mask = random_mask_func(original_mask, dilation_type, dilation_size).squeeze()
|
665 |
+
|
666 |
+
mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
|
667 |
+
|
668 |
+
masked_image = original_image * (1 - (random_mask[:,:,None]>0))
|
669 |
+
masked_image = masked_image.astype(original_image.dtype)
|
670 |
+
masked_image = Image.fromarray(masked_image)
|
671 |
+
|
672 |
+
return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
|
673 |
+
|
674 |
+
|
675 |
+
def process_erosion_mask(input_image,
|
676 |
+
original_image,
|
677 |
+
original_mask,
|
678 |
+
resize_default,
|
679 |
+
aspect_ratio_name,
|
680 |
+
dilation_size=20):
|
681 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
682 |
+
input_mask = np.asarray(alpha_mask)
|
683 |
+
output_w, output_h = aspect_ratios[aspect_ratio_name]
|
684 |
+
if output_w == "" or output_h == "":
|
685 |
+
output_h, output_w = original_image.shape[:2]
|
686 |
+
if resize_default:
|
687 |
+
short_side = min(output_w, output_h)
|
688 |
+
scale_ratio = 640 / short_side
|
689 |
+
output_w = int(output_w * scale_ratio)
|
690 |
+
output_h = int(output_h * scale_ratio)
|
691 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
692 |
+
original_image = np.array(original_image)
|
693 |
+
if input_mask is not None:
|
694 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
695 |
+
input_mask = np.array(input_mask)
|
696 |
+
if original_mask is not None:
|
697 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
698 |
+
original_mask = np.array(original_mask)
|
699 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
700 |
+
else:
|
701 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
702 |
+
pass
|
703 |
+
else:
|
704 |
+
if resize_default:
|
705 |
+
short_side = min(output_w, output_h)
|
706 |
+
scale_ratio = 640 / short_side
|
707 |
+
output_w = int(output_w * scale_ratio)
|
708 |
+
output_h = int(output_h * scale_ratio)
|
709 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
710 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
711 |
+
original_image = np.array(original_image)
|
712 |
+
if input_mask is not None:
|
713 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
714 |
+
input_mask = np.array(input_mask)
|
715 |
+
if original_mask is not None:
|
716 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
717 |
+
original_mask = np.array(original_mask)
|
718 |
+
|
719 |
+
if input_mask.max() == 0:
|
720 |
+
original_mask = original_mask
|
721 |
+
else:
|
722 |
+
original_mask = input_mask
|
723 |
+
|
724 |
+
if original_mask is None:
|
725 |
+
raise gr.Error('Please generate mask first')
|
726 |
+
|
727 |
+
if original_mask.ndim == 2:
|
728 |
+
original_mask = original_mask[:,:,None]
|
729 |
+
|
730 |
+
dilation_type = np.random.choice(['square_erosion'])
|
731 |
+
random_mask = random_mask_func(original_mask, dilation_type, dilation_size).squeeze()
|
732 |
+
|
733 |
+
mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
|
734 |
+
|
735 |
+
masked_image = original_image * (1 - (random_mask[:,:,None]>0))
|
736 |
+
masked_image = masked_image.astype(original_image.dtype)
|
737 |
+
masked_image = Image.fromarray(masked_image)
|
738 |
+
|
739 |
+
|
740 |
+
return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
|
741 |
+
|
742 |
+
|
743 |
+
def move_mask_left(input_image,
|
744 |
+
original_image,
|
745 |
+
original_mask,
|
746 |
+
moving_pixels,
|
747 |
+
resize_default,
|
748 |
+
aspect_ratio_name):
|
749 |
+
|
750 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
751 |
+
input_mask = np.asarray(alpha_mask)
|
752 |
+
|
753 |
+
output_w, output_h = aspect_ratios[aspect_ratio_name]
|
754 |
+
if output_w == "" or output_h == "":
|
755 |
+
output_h, output_w = original_image.shape[:2]
|
756 |
+
if resize_default:
|
757 |
+
short_side = min(output_w, output_h)
|
758 |
+
scale_ratio = 640 / short_side
|
759 |
+
output_w = int(output_w * scale_ratio)
|
760 |
+
output_h = int(output_h * scale_ratio)
|
761 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
762 |
+
original_image = np.array(original_image)
|
763 |
+
if input_mask is not None:
|
764 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
765 |
+
input_mask = np.array(input_mask)
|
766 |
+
if original_mask is not None:
|
767 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
768 |
+
original_mask = np.array(original_mask)
|
769 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
770 |
+
else:
|
771 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
772 |
+
pass
|
773 |
+
else:
|
774 |
+
if resize_default:
|
775 |
+
short_side = min(output_w, output_h)
|
776 |
+
scale_ratio = 640 / short_side
|
777 |
+
output_w = int(output_w * scale_ratio)
|
778 |
+
output_h = int(output_h * scale_ratio)
|
779 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
780 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
781 |
+
original_image = np.array(original_image)
|
782 |
+
if input_mask is not None:
|
783 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
784 |
+
input_mask = np.array(input_mask)
|
785 |
+
if original_mask is not None:
|
786 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
787 |
+
original_mask = np.array(original_mask)
|
788 |
+
|
789 |
+
if input_mask.max() == 0:
|
790 |
+
original_mask = original_mask
|
791 |
+
else:
|
792 |
+
original_mask = input_mask
|
793 |
+
|
794 |
+
if original_mask is None:
|
795 |
+
raise gr.Error('Please generate mask first')
|
796 |
+
|
797 |
+
if original_mask.ndim == 2:
|
798 |
+
original_mask = original_mask[:,:,None]
|
799 |
+
|
800 |
+
moved_mask = move_mask_func(original_mask, 'left', int(moving_pixels)).squeeze()
|
801 |
+
mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
|
802 |
+
|
803 |
+
masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
|
804 |
+
masked_image = masked_image.astype(original_image.dtype)
|
805 |
+
masked_image = Image.fromarray(masked_image)
|
806 |
+
|
807 |
+
if moved_mask.max() <= 1:
|
808 |
+
moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
|
809 |
+
original_mask = moved_mask
|
810 |
+
return [masked_image], [mask_image], original_mask.astype(np.uint8)
|
811 |
+
|
812 |
+
|
813 |
+
def move_mask_right(input_image,
|
814 |
+
original_image,
|
815 |
+
original_mask,
|
816 |
+
moving_pixels,
|
817 |
+
resize_default,
|
818 |
+
aspect_ratio_name):
|
819 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
820 |
+
input_mask = np.asarray(alpha_mask)
|
821 |
+
|
822 |
+
output_w, output_h = aspect_ratios[aspect_ratio_name]
|
823 |
+
if output_w == "" or output_h == "":
|
824 |
+
output_h, output_w = original_image.shape[:2]
|
825 |
+
if resize_default:
|
826 |
+
short_side = min(output_w, output_h)
|
827 |
+
scale_ratio = 640 / short_side
|
828 |
+
output_w = int(output_w * scale_ratio)
|
829 |
+
output_h = int(output_h * scale_ratio)
|
830 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
831 |
+
original_image = np.array(original_image)
|
832 |
+
if input_mask is not None:
|
833 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
834 |
+
input_mask = np.array(input_mask)
|
835 |
+
if original_mask is not None:
|
836 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
837 |
+
original_mask = np.array(original_mask)
|
838 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
839 |
+
else:
|
840 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
841 |
+
pass
|
842 |
+
else:
|
843 |
+
if resize_default:
|
844 |
+
short_side = min(output_w, output_h)
|
845 |
+
scale_ratio = 640 / short_side
|
846 |
+
output_w = int(output_w * scale_ratio)
|
847 |
+
output_h = int(output_h * scale_ratio)
|
848 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
849 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
850 |
+
original_image = np.array(original_image)
|
851 |
+
if input_mask is not None:
|
852 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
853 |
+
input_mask = np.array(input_mask)
|
854 |
+
if original_mask is not None:
|
855 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
856 |
+
original_mask = np.array(original_mask)
|
857 |
+
|
858 |
+
if input_mask.max() == 0:
|
859 |
+
original_mask = original_mask
|
860 |
+
else:
|
861 |
+
original_mask = input_mask
|
862 |
+
|
863 |
+
if original_mask is None:
|
864 |
+
raise gr.Error('Please generate mask first')
|
865 |
+
|
866 |
+
if original_mask.ndim == 2:
|
867 |
+
original_mask = original_mask[:,:,None]
|
868 |
+
|
869 |
+
moved_mask = move_mask_func(original_mask, 'right', int(moving_pixels)).squeeze()
|
870 |
+
|
871 |
+
mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
|
872 |
+
|
873 |
+
masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
|
874 |
+
masked_image = masked_image.astype(original_image.dtype)
|
875 |
+
masked_image = Image.fromarray(masked_image)
|
876 |
+
|
877 |
+
|
878 |
+
if moved_mask.max() <= 1:
|
879 |
+
moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
|
880 |
+
original_mask = moved_mask
|
881 |
+
|
882 |
+
return [masked_image], [mask_image], original_mask.astype(np.uint8)
|
883 |
+
|
884 |
+
|
885 |
+
def move_mask_up(input_image,
|
886 |
+
original_image,
|
887 |
+
original_mask,
|
888 |
+
moving_pixels,
|
889 |
+
resize_default,
|
890 |
+
aspect_ratio_name):
|
891 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
892 |
+
input_mask = np.asarray(alpha_mask)
|
893 |
+
|
894 |
+
output_w, output_h = aspect_ratios[aspect_ratio_name]
|
895 |
+
if output_w == "" or output_h == "":
|
896 |
+
output_h, output_w = original_image.shape[:2]
|
897 |
+
if resize_default:
|
898 |
+
short_side = min(output_w, output_h)
|
899 |
+
scale_ratio = 640 / short_side
|
900 |
+
output_w = int(output_w * scale_ratio)
|
901 |
+
output_h = int(output_h * scale_ratio)
|
902 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
903 |
+
original_image = np.array(original_image)
|
904 |
+
if input_mask is not None:
|
905 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
906 |
+
input_mask = np.array(input_mask)
|
907 |
+
if original_mask is not None:
|
908 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
909 |
+
original_mask = np.array(original_mask)
|
910 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
911 |
+
else:
|
912 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
913 |
+
pass
|
914 |
+
else:
|
915 |
+
if resize_default:
|
916 |
+
short_side = min(output_w, output_h)
|
917 |
+
scale_ratio = 640 / short_side
|
918 |
+
output_w = int(output_w * scale_ratio)
|
919 |
+
output_h = int(output_h * scale_ratio)
|
920 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
921 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
922 |
+
original_image = np.array(original_image)
|
923 |
+
if input_mask is not None:
|
924 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
925 |
+
input_mask = np.array(input_mask)
|
926 |
+
if original_mask is not None:
|
927 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
928 |
+
original_mask = np.array(original_mask)
|
929 |
+
|
930 |
+
if input_mask.max() == 0:
|
931 |
+
original_mask = original_mask
|
932 |
+
else:
|
933 |
+
original_mask = input_mask
|
934 |
+
|
935 |
+
if original_mask is None:
|
936 |
+
raise gr.Error('Please generate mask first')
|
937 |
+
|
938 |
+
if original_mask.ndim == 2:
|
939 |
+
original_mask = original_mask[:,:,None]
|
940 |
+
|
941 |
+
moved_mask = move_mask_func(original_mask, 'up', int(moving_pixels)).squeeze()
|
942 |
+
mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
|
943 |
+
|
944 |
+
masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
|
945 |
+
masked_image = masked_image.astype(original_image.dtype)
|
946 |
+
masked_image = Image.fromarray(masked_image)
|
947 |
+
|
948 |
+
if moved_mask.max() <= 1:
|
949 |
+
moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
|
950 |
+
original_mask = moved_mask
|
951 |
+
|
952 |
+
return [masked_image], [mask_image], original_mask.astype(np.uint8)
|
953 |
+
|
954 |
+
|
955 |
+
def move_mask_down(input_image,
|
956 |
+
original_image,
|
957 |
+
original_mask,
|
958 |
+
moving_pixels,
|
959 |
+
resize_default,
|
960 |
+
aspect_ratio_name):
|
961 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
962 |
+
input_mask = np.asarray(alpha_mask)
|
963 |
+
output_w, output_h = aspect_ratios[aspect_ratio_name]
|
964 |
+
if output_w == "" or output_h == "":
|
965 |
+
output_h, output_w = original_image.shape[:2]
|
966 |
+
if resize_default:
|
967 |
+
short_side = min(output_w, output_h)
|
968 |
+
scale_ratio = 640 / short_side
|
969 |
+
output_w = int(output_w * scale_ratio)
|
970 |
+
output_h = int(output_h * scale_ratio)
|
971 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
972 |
+
original_image = np.array(original_image)
|
973 |
+
if input_mask is not None:
|
974 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
975 |
+
input_mask = np.array(input_mask)
|
976 |
+
if original_mask is not None:
|
977 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
978 |
+
original_mask = np.array(original_mask)
|
979 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
980 |
+
else:
|
981 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
982 |
+
pass
|
983 |
+
else:
|
984 |
+
if resize_default:
|
985 |
+
short_side = min(output_w, output_h)
|
986 |
+
scale_ratio = 640 / short_side
|
987 |
+
output_w = int(output_w * scale_ratio)
|
988 |
+
output_h = int(output_h * scale_ratio)
|
989 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
990 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
991 |
+
original_image = np.array(original_image)
|
992 |
+
if input_mask is not None:
|
993 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
994 |
+
input_mask = np.array(input_mask)
|
995 |
+
if original_mask is not None:
|
996 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
997 |
+
original_mask = np.array(original_mask)
|
998 |
+
|
999 |
+
if input_mask.max() == 0:
|
1000 |
+
original_mask = original_mask
|
1001 |
+
else:
|
1002 |
+
original_mask = input_mask
|
1003 |
+
|
1004 |
+
if original_mask is None:
|
1005 |
+
raise gr.Error('Please generate mask first')
|
1006 |
+
|
1007 |
+
if original_mask.ndim == 2:
|
1008 |
+
original_mask = original_mask[:,:,None]
|
1009 |
+
|
1010 |
+
moved_mask = move_mask_func(original_mask, 'down', int(moving_pixels)).squeeze()
|
1011 |
+
mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
|
1012 |
+
|
1013 |
+
masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
|
1014 |
+
masked_image = masked_image.astype(original_image.dtype)
|
1015 |
+
masked_image = Image.fromarray(masked_image)
|
1016 |
+
|
1017 |
+
if moved_mask.max() <= 1:
|
1018 |
+
moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
|
1019 |
+
original_mask = moved_mask
|
1020 |
+
|
1021 |
+
return [masked_image], [mask_image], original_mask.astype(np.uint8)
|
1022 |
+
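# The four move_mask_* handlers above all delegate the actual pixel shift to
# move_mask_func, defined earlier in this file.  The sketch below is only an
# illustration of what such a shift can look like (np.roll plus zeroing of the
# wrapped-around border is an assumption, not the app's exact implementation).
import numpy as np

def shift_binary_mask_sketch(mask: np.ndarray, direction: str, pixels: int) -> np.ndarray:
    """Shift an (H, W) or (H, W, 1) binary mask by `pixels` without wrapping content."""
    if pixels <= 0:
        return mask.copy()
    moved = mask.copy()
    if direction == "right":
        moved = np.roll(moved, pixels, axis=1)
        moved[:, :pixels] = 0
    elif direction == "left":
        moved = np.roll(moved, -pixels, axis=1)
        moved[:, -pixels:] = 0
    elif direction == "down":
        moved = np.roll(moved, pixels, axis=0)
        moved[:pixels, :] = 0
    elif direction == "up":
        moved = np.roll(moved, -pixels, axis=0)
        moved[-pixels:, :] = 0
    return moved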
|
1023 |
+
|
1024 |
+
def invert_mask(input_image,
|
1025 |
+
original_image,
|
1026 |
+
original_mask,
|
1027 |
+
):
|
1028 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
1029 |
+
input_mask = np.asarray(alpha_mask)
|
1030 |
+
if input_mask.max() == 0:
|
1031 |
+
original_mask = 1 - (original_mask>0).astype(np.uint8)
|
1032 |
+
else:
|
1033 |
+
original_mask = 1 - (input_mask>0).astype(np.uint8)
|
1034 |
+
|
1035 |
+
if original_mask is None:
|
1036 |
+
raise gr.Error('Please generate mask first')
|
1037 |
+
|
1038 |
+
original_mask = original_mask.squeeze()
|
1039 |
+
mask_image = Image.fromarray(original_mask*255).convert("RGB")
|
1040 |
+
|
1041 |
+
if original_mask.ndim == 2:
|
1042 |
+
original_mask = original_mask[:,:,None]
|
1043 |
+
|
1044 |
+
if original_mask.max() <= 1:
|
1045 |
+
original_mask = (original_mask * 255).astype(np.uint8)
|
1046 |
+
|
1047 |
+
masked_image = original_image * (1 - (original_mask>0))
|
1048 |
+
masked_image = masked_image.astype(original_image.dtype)
|
1049 |
+
masked_image = Image.fromarray(masked_image)
|
1050 |
+
|
1051 |
+
return [masked_image], [mask_image], original_mask, True
|
1052 |
+
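# A standalone miniature of what invert_mask does: flip a {0, 255} uint8 mask
# and rebuild the blacked-out preview.  Array names here are illustrative only.
import numpy as np
from PIL import Image

def invert_and_preview_sketch(image_np: np.ndarray, mask_np: np.ndarray):
    inverted = np.where(mask_np > 0, 0, 255).astype(np.uint8)   # background becomes foreground
    keep = (inverted[:, :, None] == 0)                          # pixels outside the new mask
    preview = (image_np * keep).astype(image_np.dtype)          # black out the newly masked area
    return Image.fromarray(inverted).convert("RGB"), Image.fromarray(preview)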
|
1053 |
+
|
1054 |
+
|
1055 |
+
def reset_func(input_image,
|
1056 |
+
original_image,
|
1057 |
+
original_mask,
|
1058 |
+
prompt,
|
1059 |
+
target_prompt,
|
1060 |
+
):
|
1061 |
+
input_image = None
|
1062 |
+
original_image = None
|
1063 |
+
original_mask = None
|
1064 |
+
prompt = ''
|
1065 |
+
mask_gallery = []
|
1066 |
+
masked_gallery = []
|
1067 |
+
result_gallery = []
|
1068 |
+
target_prompt = ''
|
1069 |
+
if torch.cuda.is_available():
|
1070 |
+
torch.cuda.empty_cache()
|
1071 |
+
return input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, True, False
|
1072 |
+
|
1073 |
+
|
1074 |
+
def update_example(example_type,
|
1075 |
+
prompt,
|
1076 |
+
example_change_times):
|
1077 |
+
input_image = INPUT_IMAGE_PATH[example_type]
|
1078 |
+
image_pil = Image.open(input_image).convert("RGB")
|
1079 |
+
mask_gallery = [Image.open(MASK_IMAGE_PATH[example_type]).convert("L")]
|
1080 |
+
masked_gallery = [Image.open(MASKED_IMAGE_PATH[example_type]).convert("RGB")]
|
1081 |
+
result_gallery = [Image.open(OUTPUT_IMAGE_PATH[example_type]).convert("RGB")]
|
1082 |
+
width, height = image_pil.size
|
1083 |
+
image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
|
1084 |
+
height_new, width_new = image_processor.get_default_height_width(image_pil, height, width)
|
1085 |
+
image_pil = image_pil.resize((width_new, height_new))
|
1086 |
+
mask_gallery[0] = mask_gallery[0].resize((width_new, height_new))
|
1087 |
+
masked_gallery[0] = masked_gallery[0].resize((width_new, height_new))
|
1088 |
+
result_gallery[0] = result_gallery[0].resize((width_new, height_new))
|
1089 |
+
|
1090 |
+
original_image = np.array(image_pil)
|
1091 |
+
original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1
|
1092 |
+
aspect_ratio = "Custom resolution"
|
1093 |
+
example_change_times += 1
|
1094 |
+
return input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, "", False, example_change_times
|
1095 |
+
|
1096 |
+
|
1097 |
+
def generate_target_prompt(input_image,
|
1098 |
+
original_image,
|
1099 |
+
prompt):
|
1100 |
+
# load example image
|
1101 |
+
if isinstance(original_image, str):
|
1102 |
+
original_image = input_image
|
1103 |
+
|
1104 |
+
prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(
|
1105 |
+
vlm_processor,
|
1106 |
+
vlm_model,
|
1107 |
+
original_image,
|
1108 |
+
prompt,
|
1109 |
+
device)
|
1110 |
+
return prompt_after_apply_instruction
|
1111 |
+
|
1112 |
+
|
1113 |
+
|
1114 |
+
from app.utils.utils import generate_caption
|
1115 |
+
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
|
1116 |
+
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
1117 |
+
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large", torch_dtype=torch.float16).to(device)
|
1118 |
+
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32",torch_dtype=torch.float16).to(device)
|
1119 |
+
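# generate_caption (imported above from app.utils.utils) produces the BLIP
# description used by generate_blip_description further below.  The sketch
# here only shows what a bare BLIP captioning call with the models loaded
# above looks like; the real helper's prompt and arguments may differ.
import torch
from PIL import Image

def caption_image_sketch(image: Image.Image) -> str:
    inputs = blip_processor(images=image, return_tensors="pt")
    pixel_values = inputs.pixel_values.to(device, torch.float16)   # match the fp16 model above
    with torch.no_grad():
        output_ids = blip_model.generate(pixel_values=pixel_values, max_new_tokens=30)
    return blip_processor.decode(output_ids[0], skip_special_tokens=True)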
|
1120 |
+
|
1121 |
+
def init_img(base,
|
1122 |
+
init_type,
|
1123 |
+
prompt,
|
1124 |
+
aspect_ratio,
|
1125 |
+
example_change_times
|
1126 |
+
):
|
1127 |
+
image_pil = base["background"].convert("RGB")
|
1128 |
+
original_image = np.array(image_pil)
|
1129 |
+
if max(original_image.shape[0], original_image.shape[1]) * 1.0 / min(original_image.shape[0], original_image.shape[1])>2.0:
|
1130 |
+
raise gr.Error('image aspect ratio cannot be larger than 2.0')
|
1131 |
+
if init_type in MASK_IMAGE_PATH.keys() and example_change_times < 2:
|
1132 |
+
mask_gallery = [Image.open(MASK_IMAGE_PATH[init_type]).convert("L")]
|
1133 |
+
masked_gallery = [Image.open(MASKED_IMAGE_PATH[init_type]).convert("RGB")]
|
1134 |
+
result_gallery = [Image.open(OUTPUT_IMAGE_PATH[init_type]).convert("RGB")]
|
1135 |
+
width, height = image_pil.size
|
1136 |
+
image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
|
1137 |
+
height_new, width_new = image_processor.get_default_height_width(image_pil, height, width)
|
1138 |
+
image_pil = image_pil.resize((width_new, height_new))
|
1139 |
+
mask_gallery[0] = mask_gallery[0].resize((width_new, height_new))
|
1140 |
+
masked_gallery[0] = masked_gallery[0].resize((width_new, height_new))
|
1141 |
+
result_gallery[0] = result_gallery[0].resize((width_new, height_new))
|
1142 |
+
original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1
|
1143 |
+
return base, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, "", "", "Custom resolution", False, False, example_change_times
|
1144 |
+
else:
|
1145 |
+
if aspect_ratio not in ASPECT_RATIO_LABELS:
|
1146 |
+
aspect_ratio = "Custom resolution"
|
1147 |
+
return base, original_image, None, "", None, None, None, "", "", aspect_ratio, True, False, 0
|
1148 |
+
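# init_img relies on VaeImageProcessor.get_default_height_width to snap the
# upload to dimensions the VAE can process.  In effect each side is rounded
# down to a multiple of the VAE scale factor (8 for Stable Diffusion); the
# helper below restates that rule for clarity and is only an approximation of
# the diffusers behaviour, not a replacement for it.
def snap_to_vae_stride_sketch(width: int, height: int, vae_scale_factor: int = 8):
    return (width // vae_scale_factor) * vae_scale_factor, (height // vae_scale_factor) * vae_scale_factor

# e.g. snap_to_vae_stride_sketch(1023, 765) -> (1016, 760)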
|
1149 |
+
|
1150 |
+
def process_mask(input_image,
|
1151 |
+
original_image,
|
1152 |
+
prompt,
|
1153 |
+
resize_default,
|
1154 |
+
aspect_ratio_name):
|
1155 |
+
if original_image is None:
|
1156 |
+
raise gr.Error('Please upload the input image')
|
1157 |
+
if prompt is None:
|
1158 |
+
raise gr.Error("Please input your instructions, e.g., remove the xxx")
|
1159 |
+
|
1160 |
+
## load mask
|
1161 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
1162 |
+
input_mask = np.array(alpha_mask)
|
1163 |
+
|
1164 |
+
# load example image
|
1165 |
+
if isinstance(original_image, str):
|
1166 |
+
original_image = input_image["background"]
|
1167 |
+
|
1168 |
+
if input_mask.max() == 0:
|
1169 |
+
category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device)
|
1170 |
+
|
1171 |
+
object_wait_for_edit = vlm_response_object_wait_for_edit(vlm_processor,
|
1172 |
+
vlm_model,
|
1173 |
+
original_image,
|
1174 |
+
category,
|
1175 |
+
prompt,
|
1176 |
+
device)
|
1177 |
+
# original mask: h,w,1 [0, 255]
|
1178 |
+
original_mask = vlm_response_mask(
|
1179 |
+
vlm_processor,
|
1180 |
+
vlm_model,
|
1181 |
+
category,
|
1182 |
+
original_image,
|
1183 |
+
prompt,
|
1184 |
+
object_wait_for_edit,
|
1185 |
+
sam,
|
1186 |
+
sam_predictor,
|
1187 |
+
sam_automask_generator,
|
1188 |
+
groundingdino_model,
|
1189 |
+
device).astype(np.uint8)
|
1190 |
+
else:
|
1191 |
+
original_mask = input_mask.astype(np.uint8)
|
1192 |
+
category = None
|
1193 |
+
|
1194 |
+
## resize mask if needed
|
1195 |
+
output_w, output_h = aspect_ratios[aspect_ratio_name]
|
1196 |
+
if output_w == "" or output_h == "":
|
1197 |
+
output_h, output_w = original_image.shape[:2]
|
1198 |
+
if resize_default:
|
1199 |
+
short_side = min(output_w, output_h)
|
1200 |
+
scale_ratio = 640 / short_side
|
1201 |
+
output_w = int(output_w * scale_ratio)
|
1202 |
+
output_h = int(output_h * scale_ratio)
|
1203 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
1204 |
+
original_image = np.array(original_image)
|
1205 |
+
if input_mask is not None:
|
1206 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
1207 |
+
input_mask = np.array(input_mask)
|
1208 |
+
if original_mask is not None:
|
1209 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
1210 |
+
original_mask = np.array(original_mask)
|
1211 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1212 |
+
else:
|
1213 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1214 |
+
pass
|
1215 |
+
else:
|
1216 |
+
if resize_default:
|
1217 |
+
short_side = min(output_w, output_h)
|
1218 |
+
scale_ratio = 640 / short_side
|
1219 |
+
output_w = int(output_w * scale_ratio)
|
1220 |
+
output_h = int(output_h * scale_ratio)
|
1221 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1222 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
1223 |
+
original_image = np.array(original_image)
|
1224 |
+
if input_mask is not None:
|
1225 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
1226 |
+
input_mask = np.array(input_mask)
|
1227 |
+
if original_mask is not None:
|
1228 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
1229 |
+
original_mask = np.array(original_mask)
|
1230 |
+
|
1231 |
+
|
1232 |
+
if original_mask.ndim == 2:
|
1233 |
+
original_mask = original_mask[:,:,None]
|
1234 |
+
|
1235 |
+
mask_image = Image.fromarray(original_mask.squeeze().astype(np.uint8)).convert("RGB")
|
1236 |
+
|
1237 |
+
masked_image = original_image * (1 - (original_mask>0))
|
1238 |
+
masked_image = masked_image.astype(np.uint8)
|
1239 |
+
masked_image = Image.fromarray(masked_image)
|
1240 |
+
|
1241 |
+
return [masked_image], [mask_image], original_mask.astype(np.uint8), category
|
1242 |
+
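# The resize_default branches above (and the identical ones repeated in the
# other handlers) rescale the working resolution so the short side becomes
# 640 px.  The rule in isolation:
def target_size_short_side_640(width: int, height: int):
    scale = 640 / min(width, height)
    return int(width * scale), int(height * scale)

# e.g. target_size_short_side_640(1280, 720) -> (1137, 640)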
|
1243 |
+
|
1244 |
+
def process(input_image,
|
1245 |
+
original_image,
|
1246 |
+
original_mask,
|
1247 |
+
prompt,
|
1248 |
+
negative_prompt,
|
1249 |
+
control_strength,
|
1250 |
+
seed,
|
1251 |
+
randomize_seed,
|
1252 |
+
guidance_scale,
|
1253 |
+
num_inference_steps,
|
1254 |
+
num_samples,
|
1255 |
+
blending,
|
1256 |
+
category,
|
1257 |
+
target_prompt,
|
1258 |
+
resize_default,
|
1259 |
+
aspect_ratio_name,
|
1260 |
+
invert_mask_state):
|
1261 |
+
if original_image is None:
|
1262 |
+
if input_image is None:
|
1263 |
+
raise gr.Error('Please upload the input image')
|
1264 |
+
else:
|
1265 |
+
image_pil = input_image["background"].convert("RGB")
|
1266 |
+
original_image = np.array(image_pil)
|
1267 |
+
if prompt is None or prompt == "":
|
1268 |
+
if target_prompt is None or target_prompt == "":
|
1269 |
+
raise gr.Error("Please input your instructions, e.g., remove the xxx")
|
1270 |
+
|
1271 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
1272 |
+
input_mask = np.asarray(alpha_mask)
|
1273 |
+
output_w, output_h = aspect_ratios[aspect_ratio_name]
|
1274 |
+
if output_w == "" or output_h == "":
|
1275 |
+
output_h, output_w = original_image.shape[:2]
|
1276 |
+
|
1277 |
+
if resize_default:
|
1278 |
+
short_side = min(output_w, output_h)
|
1279 |
+
scale_ratio = 640 / short_side
|
1280 |
+
output_w = int(output_w * scale_ratio)
|
1281 |
+
output_h = int(output_h * scale_ratio)
|
1282 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
1283 |
+
original_image = np.array(original_image)
|
1284 |
+
if input_mask is not None:
|
1285 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
1286 |
+
input_mask = np.array(input_mask)
|
1287 |
+
if original_mask is not None:
|
1288 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
1289 |
+
original_mask = np.array(original_mask)
|
1290 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1291 |
+
else:
|
1292 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1293 |
+
pass
|
1294 |
+
else:
|
1295 |
+
if resize_default:
|
1296 |
+
short_side = min(output_w, output_h)
|
1297 |
+
scale_ratio = 640 / short_side
|
1298 |
+
output_w = int(output_w * scale_ratio)
|
1299 |
+
output_h = int(output_h * scale_ratio)
|
1300 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1301 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
1302 |
+
original_image = np.array(original_image)
|
1303 |
+
if input_mask is not None:
|
1304 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
1305 |
+
input_mask = np.array(input_mask)
|
1306 |
+
if original_mask is not None:
|
1307 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
1308 |
+
original_mask = np.array(original_mask)
|
1309 |
+
|
1310 |
+
if invert_mask_state:
|
1311 |
+
original_mask = original_mask
|
1312 |
+
else:
|
1313 |
+
if input_mask.max() == 0:
|
1314 |
+
original_mask = original_mask
|
1315 |
+
else:
|
1316 |
+
original_mask = input_mask
|
1317 |
+
|
1318 |
+
|
1319 |
+
# inpainting directly if target_prompt is not None
|
1320 |
+
if category is not None:
|
1321 |
+
pass
|
1322 |
+
elif target_prompt is not None and len(target_prompt) >= 1 and original_mask is not None:
|
1323 |
+
pass
|
1324 |
+
else:
|
1325 |
+
try:
|
1326 |
+
category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device)
|
1327 |
+
except Exception as e:
|
1328 |
+
raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
|
1329 |
+
|
1330 |
+
|
1331 |
+
if original_mask is not None:
|
1332 |
+
original_mask = np.clip(original_mask, 0, 255).astype(np.uint8)
|
1333 |
+
else:
|
1334 |
+
try:
|
1335 |
+
object_wait_for_edit = vlm_response_object_wait_for_edit(
|
1336 |
+
vlm_processor,
|
1337 |
+
vlm_model,
|
1338 |
+
original_image,
|
1339 |
+
category,
|
1340 |
+
prompt,
|
1341 |
+
device)
|
1342 |
+
|
1343 |
+
original_mask = vlm_response_mask(vlm_processor,
|
1344 |
+
vlm_model,
|
1345 |
+
category,
|
1346 |
+
original_image,
|
1347 |
+
prompt,
|
1348 |
+
object_wait_for_edit,
|
1349 |
+
sam,
|
1350 |
+
sam_predictor,
|
1351 |
+
sam_automask_generator,
|
1352 |
+
groundingdino_model,
|
1353 |
+
device).astype(np.uint8)
|
1354 |
+
except Exception as e:
|
1355 |
+
raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
|
1356 |
+
|
1357 |
+
if original_mask.ndim == 2:
|
1358 |
+
original_mask = original_mask[:,:,None]
|
1359 |
+
|
1360 |
+
|
1361 |
+
if target_prompt is not None and len(target_prompt) >= 1:
|
1362 |
+
prompt_after_apply_instruction = target_prompt
|
1363 |
+
|
1364 |
+
else:
|
1365 |
+
try:
|
1366 |
+
prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(
|
1367 |
+
vlm_processor,
|
1368 |
+
vlm_model,
|
1369 |
+
original_image,
|
1370 |
+
prompt,
|
1371 |
+
device)
|
1372 |
+
except Exception as e:
|
1373 |
+
raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
|
1374 |
+
|
1375 |
+
generator = torch.Generator(device).manual_seed(random.randint(0, 2147483647) if randomize_seed else seed)
|
1376 |
+
|
1377 |
+
|
1378 |
+
with torch.autocast(device):
|
1379 |
+
image, mask_image, mask_np, init_image_np = BrushEdit_Pipeline(pipe,
|
1380 |
+
prompt_after_apply_instruction,
|
1381 |
+
original_mask,
|
1382 |
+
original_image,
|
1383 |
+
generator,
|
1384 |
+
num_inference_steps,
|
1385 |
+
guidance_scale,
|
1386 |
+
control_strength,
|
1387 |
+
negative_prompt,
|
1388 |
+
num_samples,
|
1389 |
+
blending)
|
1390 |
+
original_image = np.array(init_image_np)
|
1391 |
+
masked_image = original_image * (1 - (mask_np>0))
|
1392 |
+
masked_image = masked_image.astype(np.uint8)
|
1393 |
+
masked_image = Image.fromarray(masked_image)
|
1394 |
+
|
1395 |
+
return image, [mask_image], [masked_image], prompt, '', False
|
1396 |
+
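# process() derives its torch.Generator from either the fixed seed or a fresh
# random one when randomize_seed is ticked; the same logic as a small helper:
import random
import torch

def make_generator(device: str, seed: int, randomize_seed: bool) -> torch.Generator:
    actual_seed = random.randint(0, 2147483647) if randomize_seed else seed
    return torch.Generator(device).manual_seed(actual_seed)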
|
1397 |
+
|
1398 |
+
def process_cirr_images():
|
1399 |
+
# Make sure the VLM/SAM models are initialized (the actual loading code lives elsewhere)
|
1400 |
+
global vlm_model, sam_predictor, groundingdino_model
|
1401 |
+
if not all([vlm_model, sam_predictor, groundingdino_model]):
|
1402 |
+
raise RuntimeError("Required models not initialized")
|
1403 |
+
|
1404 |
+
# Define paths
|
1405 |
+
dev_dir = Path("/home/zt/data/BrushEdit/cirr/img_raw/dev")
|
1406 |
+
cap_file = Path("/home/zt/data/BrushEdit/cirr/captions/cap.rc2.val.json")
|
1407 |
+
output_dirs = {
|
1408 |
+
"edited": Path("/home/zt/data/BrushEdit/cirr/img_paint_fix/cirr_edited"),
|
1409 |
+
"mask": Path("/home/zt/data/BrushEdit/cirr/img_paint_fix/cirr_mask"),
|
1410 |
+
"masked": Path("/home/zt/data/BrushEdit/cirr/img_paint_fix/cirr_masked")
|
1411 |
+
}
|
1412 |
+
|
1413 |
+
# Create output directories
|
1414 |
+
for dir_path in output_dirs.values():
|
1415 |
+
dir_path.mkdir(parents=True, exist_ok=True)
|
1416 |
+
|
1417 |
+
# Load captions
|
1418 |
+
with open(cap_file, 'r') as f:
|
1419 |
+
captions = json.load(f)
|
1420 |
+
|
1421 |
+
descriptions = {}
|
1422 |
+
|
1423 |
+
for img_path in dev_dir.glob("*.png"):
|
1424 |
+
base_name = img_path.stem
|
1425 |
+
caption = next((item["caption"] for item in captions if item.get("reference") == base_name), None)
|
1426 |
+
|
1427 |
+
if not caption:
|
1428 |
+
print(f"Warning: No caption for {base_name}")
|
1429 |
+
continue
|
1430 |
+
|
1431 |
+
try:
|
1432 |
+
# Key change 1: build an empty (all-zero) alpha channel
|
1433 |
+
rgb_image = Image.open(img_path).convert("RGB")
|
1434 |
+
empty_alpha = Image.new("L", rgb_image.size, 0)  # fully transparent alpha channel (nothing pre-masked)
|
1435 |
+
image = Image.merge("RGBA", (*rgb_image.split(), empty_alpha))
|
1436 |
+
|
1437 |
+
# Key change 2: initialize via init_img
|
1438 |
+
base = {"background": image, "layers": [image]}
|
1439 |
+
init_results = init_img(
|
1440 |
+
base=base,
|
1441 |
+
init_type="custom", # use the custom-initialization path
|
1442 |
+
prompt=caption,
|
1443 |
+
aspect_ratio="Custom resolution",
|
1444 |
+
example_change_times=0
|
1445 |
+
)
|
1446 |
+
|
1447 |
+
# Unpack the initialized values
|
1448 |
+
input_image = init_results[0]
|
1449 |
+
original_image = init_results[1]
|
1450 |
+
original_mask = init_results[2]
|
1451 |
+
|
1452 |
+
# Key change 3: pass the process() arguments explicitly
|
1453 |
+
process_results = process(
|
1454 |
+
input_image=input_image,
|
1455 |
+
original_image=original_image,
|
1456 |
+
original_mask=original_mask, # pass the initialized mask
|
1457 |
+
prompt=caption,
|
1458 |
+
negative_prompt="ugly, low quality",
|
1459 |
+
control_strength=1.0,
|
1460 |
+
seed=648464818,
|
1461 |
+
randomize_seed=False,
|
1462 |
+
guidance_scale=7.5,
|
1463 |
+
num_inference_steps=50,
|
1464 |
+
num_samples=1,
|
1465 |
+
blending=True,
|
1466 |
+
category=None,
|
1467 |
+
target_prompt="",
|
1468 |
+
resize_default=True,
|
1469 |
+
aspect_ratio_name="Custom resolution",
|
1470 |
+
invert_mask_state=False
|
1471 |
+
)
|
1472 |
+
|
1473 |
+
# Handle the results (same logic as before)
|
1474 |
+
result_images, mask_images, masked_images = process_results[:3]
|
1475 |
+
|
1476 |
+
# Save images
|
1477 |
+
output_dirs["edited"].mkdir(exist_ok=True)
|
1478 |
+
result_images[0].save(output_dirs["edited"] / f"{base_name}.png")
|
1479 |
+
mask_images[0].save(output_dirs["mask"] / f"{base_name}_mask.png")
|
1480 |
+
masked_images[0].save(output_dirs["masked"] / f"{base_name}_masked.png")
|
1481 |
+
|
1482 |
+
# Generate BLIP description
|
1483 |
+
blip_desc, _ = generate_blip_description({"background": image})
|
1484 |
+
descriptions[base_name] = {
|
1485 |
+
"original_caption": caption,
|
1486 |
+
"blip_description": blip_desc
|
1487 |
+
}
|
1488 |
+
|
1489 |
+
print(f"Processed {base_name}")
|
1490 |
+
|
1491 |
+
except Exception as e:
|
1492 |
+
print(f"Error processing {base_name}: {str(e)}")
|
1493 |
+
continue
|
1494 |
+
|
1495 |
+
# Save descriptions
|
1496 |
+
with open("/home/zt/data/BrushEdit/cirr/cirr_description_fix.json", 'w') as f:
|
1497 |
+
json.dump(descriptions, f, indent=4)
|
1498 |
+
|
1499 |
+
print("Processing completed!")
|
1500 |
+
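# process_cirr_images hands each CIRR picture to the pipeline as an RGBA
# image whose alpha layer is all zero, so no region counts as a hand-drawn
# mask.  The same construction in isolation (the path is illustrative):
from PIL import Image

def to_editor_input_sketch(path: str) -> dict:
    rgb = Image.open(path).convert("RGB")
    empty_alpha = Image.new("L", rgb.size, 0)                  # all-zero alpha: no region pre-masked
    rgba = Image.merge("RGBA", (*rgb.split(), empty_alpha))    # attach alpha as the fourth band
    return {"background": rgba, "layers": [rgba]}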
|
1501 |
+
|
1502 |
+
# def process_cirr_images():
|
1503 |
+
# # Define paths
|
1504 |
+
# dev_dir = Path("/home/zt/data/BrushEdit/cirr/img_raw/dev")
|
1505 |
+
# cap_file = Path("/home/zt/data/BrushEdit/cirr/captions/cap.rc2.val.json")
|
1506 |
+
# output_dirs = {
|
1507 |
+
# "edited": Path("/home/zt/data/BrushEdit/cirr/img_paint_fix/cirr_edited"),
|
1508 |
+
# "mask": Path("/home/zt/data/BrushEdit/cirr/img_paint_fix/cirr_mask"),
|
1509 |
+
# "masked": Path("/home/zt/data/BrushEdit/cirr/img_paint_fix/cirr_masked")
|
1510 |
+
# }
|
1511 |
+
|
1512 |
+
# # Create output directories if they don't exist
|
1513 |
+
# for dir_path in output_dirs.values():
|
1514 |
+
# dir_path.mkdir(parents=True, exist_ok=True)
|
1515 |
+
|
1516 |
+
# # Load captions from JSON file
|
1517 |
+
# with open(cap_file, 'r') as f:
|
1518 |
+
# captions = json.load(f)
|
1519 |
+
|
1520 |
+
# # Initialize description dictionary
|
1521 |
+
# descriptions = {}
|
1522 |
+
|
1523 |
+
# # Process each PNG image in dev directory
|
1524 |
+
# for img_path in dev_dir.glob("*.png"):
|
1525 |
+
# # Get base name without extension
|
1526 |
+
# base_name = img_path.stem
|
1527 |
+
|
1528 |
+
# # Find matching caption
|
1529 |
+
# caption = None
|
1530 |
+
# for item in captions:
|
1531 |
+
# if item.get("reference") == base_name:
|
1532 |
+
# caption = item.get("caption")
|
1533 |
+
# break
|
1534 |
+
|
1535 |
+
# if caption is None:
|
1536 |
+
# print(f"Warning: No caption found for {base_name}")
|
1537 |
+
# continue
|
1538 |
+
|
1539 |
+
# # Load and convert image to RGB
|
1540 |
+
# try:
|
1541 |
+
# rgb_image = Image.open(img_path).convert("RGB")
|
1542 |
+
# a = Image.new("L", rgb_image.size, 255) # fully opaque alpha channel
|
1543 |
+
# image = Image.merge("RGBA", (*rgb_image.split(), a))
|
1544 |
+
# except Exception as e:
|
1545 |
+
# print(f"Error loading image {img_path}: {e}")
|
1546 |
+
# continue
|
1547 |
+
|
1548 |
+
# # Generate BLIP description
|
1549 |
+
# try:
|
1550 |
+
# blip_desc, _ = generate_blip_description({"background": image})
|
1551 |
+
# except Exception as e:
|
1552 |
+
# print(f"Error generating BLIP description for {base_name}: {e}")
|
1553 |
+
# continue
|
1554 |
+
|
1555 |
+
# # Process image
|
1556 |
+
# try:
|
1557 |
+
# # Prepare input parameters for process function
|
1558 |
+
# input_image = {"background": image, "layers": [image]}
|
1559 |
+
# original_image = np.array(image)
|
1560 |
+
# original_mask = None
|
1561 |
+
# prompt = caption
|
1562 |
+
# negative_prompt = "ugly, low quality"
|
1563 |
+
# control_strength = 1.0
|
1564 |
+
# seed = 648464818
|
1565 |
+
# randomize_seed = False
|
1566 |
+
# guidance_scale = 7.5
|
1567 |
+
# num_inference_steps = 50
|
1568 |
+
# num_samples = 1
|
1569 |
+
# blending = True
|
1570 |
+
# category = None
|
1571 |
+
# target_prompt = ""
|
1572 |
+
# resize_default = True
|
1573 |
+
# aspect_ratio = "Custom resolution"
|
1574 |
+
# invert_mask_state = False
|
1575 |
+
|
1576 |
+
# # Call process function and handle return values properly
|
1577 |
+
# process_results = process(
|
1578 |
+
# input_image,
|
1579 |
+
# original_image,
|
1580 |
+
# original_mask,
|
1581 |
+
# prompt,
|
1582 |
+
# negative_prompt,
|
1583 |
+
# control_strength,
|
1584 |
+
# seed,
|
1585 |
+
# randomize_seed,
|
1586 |
+
# guidance_scale,
|
1587 |
+
# num_inference_steps,
|
1588 |
+
# num_samples,
|
1589 |
+
# blending,
|
1590 |
+
# category,
|
1591 |
+
# target_prompt,
|
1592 |
+
# resize_default,
|
1593 |
+
# aspect_ratio,
|
1594 |
+
# invert_mask_state
|
1595 |
+
# )
|
1596 |
+
|
1597 |
+
# # Extract results safely
|
1598 |
+
# result_images = process_results[0]
|
1599 |
+
# mask_images = process_results[1]
|
1600 |
+
# masked_images = process_results[2]
|
1601 |
+
|
1602 |
+
# # Ensure we have valid images to save
|
1603 |
+
# if not result_images or not mask_images or not masked_images:
|
1604 |
+
# print(f"Warning: No output images generated for {base_name}")
|
1605 |
+
# continue
|
1606 |
+
|
1607 |
+
# # Save processed images
|
1608 |
+
# # Save edited image
|
1609 |
+
# edited_path = output_dirs["edited"] / f"{base_name}.png"
|
1610 |
+
# if isinstance(result_images, (list, tuple)):
|
1611 |
+
# result_images[0].save(edited_path)
|
1612 |
+
# else:
|
1613 |
+
# result_images.save(edited_path)
|
1614 |
+
|
1615 |
+
# # Save mask image
|
1616 |
+
# mask_path = output_dirs["mask"] / f"{base_name}_mask.png"
|
1617 |
+
# if isinstance(mask_images, (list, tuple)):
|
1618 |
+
# mask_images[0].save(mask_path)
|
1619 |
+
# else:
|
1620 |
+
# mask_images.save(mask_path)
|
1621 |
+
|
1622 |
+
# # Save masked image
|
1623 |
+
# masked_path = output_dirs["masked"] / f"{base_name}_masked.png"
|
1624 |
+
# if isinstance(masked_images, (list, tuple)):
|
1625 |
+
# masked_images[0].save(masked_path)
|
1626 |
+
# else:
|
1627 |
+
# masked_images.save(masked_path)
|
1628 |
+
|
1629 |
+
# # Store description
|
1630 |
+
# descriptions[base_name] = {
|
1631 |
+
# "original_caption": caption,
|
1632 |
+
# "blip_description": blip_desc
|
1633 |
+
# }
|
1634 |
+
|
1635 |
+
# print(f"Successfully processed {base_name}")
|
1636 |
+
|
1637 |
+
# except Exception as e:
|
1638 |
+
# print(f"Error processing image {base_name}: {e}")
|
1639 |
+
# continue
|
1640 |
+
|
1641 |
+
# # Save descriptions to JSON file
|
1642 |
+
# with open("/home/zt/data/BrushEdit/cirr/cirr_description_fix.json", 'w') as f:
|
1643 |
+
# json.dump(descriptions, f, indent=4)
|
1644 |
+
|
1645 |
+
# print("Processing completed!")
|
1646 |
+
|
1647 |
+
|
1648 |
+
def generate_blip_description(input_image):
|
1649 |
+
if input_image is None:
|
1650 |
+
return "", "Input image cannot be None"
|
1651 |
+
try:
|
1652 |
+
image_pil = input_image["background"].convert("RGB")
|
1653 |
+
except KeyError:
|
1654 |
+
return "", "Input image missing 'background' key"
|
1655 |
+
except AttributeError as e:
|
1656 |
+
return "", f"Invalid image object: {str(e)}"
|
1657 |
+
try:
|
1658 |
+
description = generate_caption(blip_processor, blip_model, image_pil, device)
|
1659 |
+
return description, description # update both the state and the visible textbox
|
1660 |
+
except Exception as e:
|
1661 |
+
return "", f"Caption generation failed: {str(e)}"
|
1662 |
+
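# Usage sketch for generate_blip_description: it expects the Gradio
# ImageEditor-style dict with a "background" image and returns the caption
# twice (once for the gr.State, once for the visible textbox).  The path
# below is purely illustrative.
#   example_img = Image.open("reference.png").convert("RGBA")
#   caption_state, caption_shown = generate_blip_description({"background": example_img})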
|
1663 |
+
|
1664 |
+
def submit_GPT4o_KEY(GPT4o_KEY):
|
1665 |
+
global vlm_model, vlm_processor
|
1666 |
+
if vlm_model is not None:
|
1667 |
+
del vlm_model
|
1668 |
+
torch.cuda.empty_cache()
|
1669 |
+
try:
|
1670 |
+
vlm_model = OpenAI(api_key=GPT4o_KEY, base_url="https://api.deepseek.com")
|
1671 |
+
vlm_processor = ""
|
1672 |
+
response = vlm_model.chat.completions.create(
|
1673 |
+
model="deepseek-chat",
|
1674 |
+
messages=[
|
1675 |
+
{"role": "system", "content": "You are a helpful assistant."},
|
1676 |
+
{"role": "user", "content": "Hello."}
|
1677 |
+
]
|
1678 |
+
)
|
1679 |
+
response_str = response.choices[0].message.content
|
1680 |
+
|
1681 |
+
return "Success. " + response_str, "GPT4-o (Highly Recommended)"
|
1682 |
+
except Exception as e:
|
1683 |
+
return "Invalid GPT4o API Key", "GPT4-o (Highly Recommended)"
|
1684 |
+
|
1685 |
+
|
1686 |
+
def verify_deepseek_api():
|
1687 |
+
try:
|
1688 |
+
response = llm_model.chat.completions.create(
|
1689 |
+
model="deepseek-chat",
|
1690 |
+
messages=[
|
1691 |
+
{"role": "system", "content": "You are a helpful assistant."},
|
1692 |
+
{"role": "user", "content": "Hello."}
|
1693 |
+
]
|
1694 |
+
)
|
1695 |
+
response_str = response.choices[0].message.content
|
1696 |
+
|
1697 |
+
return True, "Success. " + response_str
|
1698 |
+
|
1699 |
+
except Exception as e:
|
1700 |
+
return False, "Invalid DeepSeek API Key"
|
1701 |
+
|
1702 |
+
|
1703 |
+
def llm_enhanced_prompt_after_apply_instruction(image_caption, editing_prompt):
|
1704 |
+
try:
|
1705 |
+
messages = create_apply_editing_messages_deepseek(image_caption, editing_prompt)
|
1706 |
+
response = llm_model.chat.completions.create(
|
1707 |
+
model="deepseek-chat",
|
1708 |
+
messages=messages
|
1709 |
+
)
|
1710 |
+
response_str = response.choices[0].message.content
|
1711 |
+
return response_str
|
1712 |
+
except Exception as e:
|
1713 |
+
raise gr.Error(f"整合指令时遇到错误: {str(e)},请检查控制台日志获取详细信息")
|
1714 |
+
|
1715 |
+
|
1716 |
+
def llm_decomposed_prompt_after_apply_instruction(integrated_query):
|
1717 |
+
try:
|
1718 |
+
messages = create_decomposed_query_messages_deepseek(integrated_query)
|
1719 |
+
response = llm_model.chat.completions.create(
|
1720 |
+
model="deepseek-chat",
|
1721 |
+
messages=messages
|
1722 |
+
)
|
1723 |
+
response_str = response.choices[0].message.content
|
1724 |
+
return response_str
|
1725 |
+
except Exception as e:
|
1726 |
+
raise gr.Error(f"分解指令时遇到错误: {str(e)},请检查控制台日志获取详细信息")
|
1727 |
+
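# The two helpers above depend on message builders imported from
# app.deepseek.instructions; their real prompt wording lives there.  The
# function below only illustrates the chat-message structure such a builder
# returns (the wording is a placeholder, not the actual prompt).
def example_apply_editing_messages(image_caption: str, editing_prompt: str) -> list:
    return [
        {"role": "system",
         "content": "You rewrite an image caption so that it reflects an editing instruction."},
        {"role": "user",
         "content": f"Caption: {image_caption}\nInstruction: {editing_prompt}\nReturn the edited caption."},
    ]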
|
1728 |
+
|
1729 |
+
def enhance_description(blip_description, prompt):
|
1730 |
+
try:
|
1731 |
+
if not prompt or not blip_description:
|
1732 |
+
print("Empty prompt or blip_description detected")
|
1733 |
+
return "", ""
|
1734 |
+
|
1735 |
+
print(f"Enhancing with prompt: {prompt}")
|
1736 |
+
enhanced_description = llm_enhanced_prompt_after_apply_instruction(blip_description, prompt)
|
1737 |
+
return enhanced_description, enhanced_description
|
1738 |
+
|
1739 |
+
except Exception as e:
|
1740 |
+
print(f"Enhancement failed: {str(e)}")
|
1741 |
+
return "Error occurred", "Error occurred"
|
1742 |
+
|
1743 |
+
|
1744 |
+
def decompose_description(enhanced_description):
|
1745 |
+
try:
|
1746 |
+
if not enhanced_description:
|
1747 |
+
print("Empty enhanced_description detected")
|
1748 |
+
return "", ""
|
1749 |
+
|
1750 |
+
print(f"Decomposing the enhanced description: {enhanced_description}")
|
1751 |
+
decomposed_description = llm_decomposed_prompt_after_apply_instruction(enhanced_description)
|
1752 |
+
return decomposed_description, decomposed_description
|
1753 |
+
|
1754 |
+
except Exception as e:
|
1755 |
+
print(f"Decomposition failed: {str(e)}")
|
1756 |
+
return "Error occurred", "Error occurred"
|
1757 |
+
|
1758 |
+
|
1759 |
+
@torch.no_grad()
|
1760 |
+
def mix_and_search(enhanced_text: str, gallery_images: list):
|
1761 |
+
# Take the most recently generated image tuple from the gallery
|
1762 |
+
latest_item = gallery_images[-1] if gallery_images else None
|
1763 |
+
|
1764 |
+
# 初始化特征列表
|
1765 |
+
features = []
|
1766 |
+
|
1767 |
+
# Image feature extraction
|
1768 |
+
if latest_item and isinstance(latest_item, tuple):
|
1769 |
+
try:
|
1770 |
+
image_path = latest_item[0]
|
1771 |
+
pil_image = Image.open(image_path).convert("RGB")
|
1772 |
+
|
1773 |
+
# Preprocess the image with CLIPProcessor
|
1774 |
+
image_inputs = clip_processor(
|
1775 |
+
images=pil_image,
|
1776 |
+
return_tensors="pt"
|
1777 |
+
).to(device)
|
1778 |
+
|
1779 |
+
image_features = clip_model.get_image_features(**image_inputs)
|
1780 |
+
features.append(F.normalize(image_features, dim=-1))
|
1781 |
+
except Exception as e:
|
1782 |
+
print(f"图像处理失败: {str(e)}")
|
1783 |
+
|
1784 |
+
# Text feature extraction
|
1785 |
+
if enhanced_text.strip():
|
1786 |
+
text_inputs = clip_processor(
|
1787 |
+
text=enhanced_text,
|
1788 |
+
return_tensors="pt",
|
1789 |
+
padding=True,
|
1790 |
+
truncation=True
|
1791 |
+
).to(device)
|
1792 |
+
|
1793 |
+
text_features = clip_model.get_text_features(**text_inputs)
|
1794 |
+
features.append(F.normalize(text_features, dim=-1))
|
1795 |
+
|
1796 |
+
if not features:
|
1797 |
+
return []
|
1798 |
+
|
1799 |
+
|
1800 |
+
# Fuse features and run retrieval
|
1801 |
+
mixed = sum(features) / len(features)
|
1802 |
+
mixed = F.normalize(mixed, dim=-1)
|
1803 |
+
|
1804 |
+
# Load the Faiss index and the image-path mapping
|
1805 |
+
# index_path = "/home/zt/data/open-images/train/knn.index"
|
1806 |
+
# input_data_dir = Path("/home/zt/data/open-images/train/embedding_folder/metadata")
|
1807 |
+
# base_image_dir = Path("/home/zt/data/open-images/train/")
|
1808 |
+
|
1809 |
+
index_path = "/home/zt/data/BrushEdit/cirr/img_raw/dev/dev_knn.index"
|
1810 |
+
input_data_dir = Path("/home/zt/data/BrushEdit/cirr/img_raw/dev/dev_embedding_folder/metadata")
|
1811 |
+
base_image_dir = Path("/home/zt/data/BrushEdit/cirr/img_raw/")
|
1812 |
+
|
1813 |
+
# Sort parquet files by the numeric suffix in their filenames and read them directly
|
1814 |
+
parquet_files = sorted(
|
1815 |
+
input_data_dir.glob('*.parquet'),
|
1816 |
+
key=lambda x: int(x.stem.split("_")[-1])
|
1817 |
+
)
|
1818 |
+
|
1819 |
+
# Concatenate all parquet data
|
1820 |
+
dfs = [pd.read_parquet(file) for file in parquet_files] # read each file inline
|
1821 |
+
df = pd.concat(dfs, ignore_index=True)
|
1822 |
+
image_paths = df["image_path"].tolist()
|
1823 |
+
|
1824 |
+
# Read the Faiss index
|
1825 |
+
index = faiss.read_index(index_path)
|
1826 |
+
assert mixed.shape[1] == index.d, "Feature dimension mismatch"
|
1827 |
+
|
1828 |
+
# Run the nearest-neighbour search
|
1829 |
+
mixed = mixed.cpu().detach().numpy().astype('float32')
|
1830 |
+
distances, indices = index.search(mixed, 50)
|
1831 |
+
|
1832 |
+
# Collect and validate the retrieved image paths
|
1833 |
+
retrieved_images = []
|
1834 |
+
for idx in indices[0]:
|
1835 |
+
if 0 <= idx < len(image_paths):
|
1836 |
+
img_path = base_image_dir / image_paths[idx]
|
1837 |
+
try:
|
1838 |
+
if img_path.exists():
|
1839 |
+
retrieved_images.append(Image.open(img_path).convert("RGB"))
|
1840 |
+
else:
|
1841 |
+
print(f"警告:文件缺失 {img_path}")
|
1842 |
+
except Exception as e:
|
1843 |
+
print(f"图片加载失败: {str(e)}")
|
1844 |
+
|
1845 |
+
return retrieved_images if retrieved_images else []
|
1846 |
+
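# mix_and_search averages the L2-normalised CLIP image and text features and
# queries a prebuilt Faiss index.  A self-contained miniature of that
# retrieval step, with random vectors standing in for real CLIP embeddings:
import numpy as np
import faiss

dim = 512                                               # CLIP ViT-B/32 embedding size
db = np.random.rand(1000, dim).astype("float32")
db /= np.linalg.norm(db, axis=1, keepdims=True)         # unit-normalise the database vectors

demo_index = faiss.IndexFlatIP(dim)                     # inner product == cosine on unit vectors
demo_index.add(db)

img_feat = np.random.rand(1, dim).astype("float32")
txt_feat = np.random.rand(1, dim).astype("float32")
query = (img_feat + txt_feat) / 2
query /= np.linalg.norm(query, axis=1, keepdims=True)

scores, ids = demo_index.search(query, 5)               # top-5 most similar database entries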
|
1847 |
+
|
1848 |
+
|
1849 |
+
if __name__ == "__main__":
|
1850 |
+
process_cirr_images()
|
1851 |
+
|
1852 |
+
|
1853 |
+
# block = gr.Blocks(
|
1854 |
+
# theme=gr.themes.Soft(
|
1855 |
+
# radius_size=gr.themes.sizes.radius_none,
|
1856 |
+
# text_size=gr.themes.sizes.text_md
|
1857 |
+
# )
|
1858 |
+
# )
|
1859 |
+
|
1860 |
+
# with block as demo:
|
1861 |
+
# with gr.Row():
|
1862 |
+
# with gr.Column():
|
1863 |
+
# gr.HTML(head)
|
1864 |
+
# gr.Markdown(descriptions)
|
1865 |
+
# with gr.Accordion(label="🧭 Instructions:", open=True, elem_id="accordion"):
|
1866 |
+
# with gr.Row(equal_height=True):
|
1867 |
+
# gr.Markdown(instructions)
|
1868 |
+
|
1869 |
+
# original_image = gr.State(value=None)
|
1870 |
+
# original_mask = gr.State(value=None)
|
1871 |
+
# category = gr.State(value=None)
|
1872 |
+
# status = gr.State(value=None)
|
1873 |
+
# invert_mask_state = gr.State(value=False)
|
1874 |
+
# example_change_times = gr.State(value=0)
|
1875 |
+
# deepseek_verified = gr.State(value=False)
|
1876 |
+
# blip_description = gr.State(value="")
|
1877 |
+
# enhanced_description = gr.State(value="")
|
1878 |
+
# decomposed_description = gr.State(value="")
|
1879 |
+
|
1880 |
+
# with gr.Row():
|
1881 |
+
# with gr.Column():
|
1882 |
+
# with gr.Group():
|
1883 |
+
# input_image = gr.ImageEditor(
|
1884 |
+
# label="参考图像",
|
1885 |
+
# type="pil",
|
1886 |
+
# brush=gr.Brush(colors=["#FFFFFF"], default_size = 30, color_mode="fixed"),
|
1887 |
+
# layers = False,
|
1888 |
+
# interactive=True,
|
1889 |
+
# # height=1024,
|
1890 |
+
# height=412,
|
1891 |
+
# sources=["upload"],
|
1892 |
+
# placeholder="🫧 点击此处或下面的图标上传图像 🫧",
|
1893 |
+
# )
|
1894 |
+
# prompt = gr.Textbox(label="修改指令", placeholder="😜 在此处输入你对参考图像的修改预期 😜", value="",lines=1)
|
1895 |
+
|
1896 |
+
# with gr.Group():
|
1897 |
+
# mask_button = gr.Button("💎 掩膜生成")
|
1898 |
+
# with gr.Row():
|
1899 |
+
# invert_mask_button = gr.Button("👐 掩膜翻转")
|
1900 |
+
# random_mask_button = gr.Button("⭕️ 随机掩膜")
|
1901 |
+
# with gr.Row():
|
1902 |
+
# masked_gallery = gr.Gallery(label="掩膜图像", show_label=True, preview=True, height=360)
|
1903 |
+
# mask_gallery = gr.Gallery(label="掩膜", show_label=True, preview=True, height=360)
|
1904 |
+
|
1905 |
+
|
1906 |
+
# with gr.Accordion("高级掩膜选项", open=False, elem_id="accordion1"):
|
1907 |
+
# dilation_size = gr.Slider(
|
1908 |
+
# label="每次放缩的尺度: ", show_label=True,minimum=0, maximum=50, step=1, value=20
|
1909 |
+
# )
|
1910 |
+
# with gr.Row():
|
1911 |
+
# dilation_mask_button = gr.Button("放大掩膜")
|
1912 |
+
# erosion_mask_button = gr.Button("缩小掩膜")
|
1913 |
+
|
1914 |
+
# moving_pixels = gr.Slider(
|
1915 |
+
# label="每次移动的像素:", show_label=True, minimum=0, maximum=50, value=4, step=1
|
1916 |
+
# )
|
1917 |
+
# with gr.Row():
|
1918 |
+
# move_left_button = gr.Button("左移")
|
1919 |
+
# move_right_button = gr.Button("右移")
|
1920 |
+
# with gr.Row():
|
1921 |
+
# move_up_button = gr.Button("上移")
|
1922 |
+
# move_down_button = gr.Button("下移")
|
1923 |
+
|
1924 |
+
|
1925 |
+
|
1926 |
+
# with gr.Column():
|
1927 |
+
# with gr.Row():
|
1928 |
+
# deepseek_key = gr.Textbox(label="LLM API密钥", value="sk-d145b963a92649a88843caeb741e8bbc", lines=2, container=False)
|
1929 |
+
# verify_deepseek = gr.Button("🔑 验证密钥", scale=0)
|
1930 |
+
# blip_output = gr.Textbox(label="1. 原图描述(BLIP生成)", placeholder="🖼️ 上传图片后自动生成图片描述 🖼️", lines=2, interactive=True)
|
1931 |
+
# with gr.Row():
|
1932 |
+
# enhanced_output = gr.Textbox(label="2. 整合增强版", lines=4, interactive=True, placeholder="🚀 点击右侧按钮生成增强描述 🚀")
|
1933 |
+
# enhance_button = gr.Button("✨ 智能整合")
|
1934 |
+
|
1935 |
+
# with gr.Row():
|
1936 |
+
# decomposed_output = gr.Textbox(label="3. 结构分解版", lines=4, interactive=True, placeholder="📝 点击右侧按钮生成结构化描述 📝")
|
1937 |
+
# decompose_button = gr.Button("🔧 结构分解")
|
1938 |
+
|
1939 |
+
|
1940 |
+
|
1941 |
+
# with gr.Group():
|
1942 |
+
# run_button = gr.Button("💫 图像编辑")
|
1943 |
+
# result_gallery = gr.Gallery(label="💥 编辑结果", show_label=True, columns=2, preview=True, height=360)
|
1944 |
+
|
1945 |
+
# with gr.Accordion("高级编辑选项", open=False, elem_id="accordion1"):
|
1946 |
+
# vlm_model_dropdown = gr.Dropdown(label="VLM 模型", choices=VLM_MODEL_NAMES, value=DEFAULT_VLM_MODEL_NAME, interactive=True)
|
1947 |
+
|
1948 |
+
# with gr.Group():
|
1949 |
+
# with gr.Row():
|
1950 |
+
# # GPT4o_KEY = gr.Textbox(label="GPT4o API Key", placeholder="Please input your GPT4o API Key when use GPT4o VLM (highly recommended).", value="", lines=1)
|
1951 |
+
# GPT4o_KEY = gr.Textbox(label="VLM API密钥", value="sk-d145b963a92649a88843caeb741e8bbc", lines=1)
|
1952 |
+
# GPT4o_KEY_submit = gr.Button("🔑 验证密钥")
|
1953 |
+
|
1954 |
+
# aspect_ratio = gr.Dropdown(label="输出纵横比", choices=ASPECT_RATIO_LABELS, value=DEFAULT_ASPECT_RATIO)
|
1955 |
+
# resize_default = gr.Checkbox(label="短边裁剪到640像素", value=True)
|
1956 |
+
# base_model_dropdown = gr.Dropdown(label="基础模型", choices=BASE_MODELS, value=DEFAULT_BASE_MODEL, interactive=True)
|
1957 |
+
# negative_prompt = gr.Text(label="负向提示", max_lines=5, placeholder="请输入你的负向提示", value='ugly, low quality',lines=1)
|
1958 |
+
# control_strength = gr.Slider(label="控制强度: ", show_label=True, minimum=0, maximum=1.1, value=1, step=0.01)
|
1959 |
+
# with gr.Group():
|
1960 |
+
# seed = gr.Slider(label="种子: ", minimum=0, maximum=2147483647, step=1, value=648464818)
|
1961 |
+
# randomize_seed = gr.Checkbox(label="随机种子", value=False)
|
1962 |
+
# blending = gr.Checkbox(label="混合模式", value=True)
|
1963 |
+
# num_samples = gr.Slider(label="生成个数", minimum=0, maximum=4, step=1, value=2)
|
1964 |
+
# with gr.Group():
|
1965 |
+
# with gr.Row():
|
1966 |
+
# guidance_scale = gr.Slider(label="指导尺度", minimum=1, maximum=12, step=0.1, value=7.5)
|
1967 |
+
# num_inference_steps = gr.Slider(label="推理步数", minimum=1, maximum=50, step=1, value=50)
|
1968 |
+
# target_prompt = gr.Text(label="Input Target Prompt", max_lines=5, placeholder="VLM-generated target prompt, you can first generate if and then modify it (optional)", value='', lines=2)
|
1969 |
+
|
1970 |
+
|
1971 |
+
|
1972 |
+
# init_type = gr.Textbox(label="Init Name", value="", visible=False)
|
1973 |
+
# example_type = gr.Textbox(label="Example Name", value="", visible=False)
|
1974 |
+
|
1975 |
+
# with gr.Row():
|
1976 |
+
# reset_button = gr.Button("Reset")
|
1977 |
+
# retrieve_button = gr.Button("🔍 开始检索")
|
1978 |
+
|
1979 |
+
# with gr.Row():
|
1980 |
+
# retrieve_gallery = gr.Gallery(label="🎊 检索结果", show_label=True, columns=10, preview=True, height=660)
|
1981 |
+
|
1982 |
+
|
1983 |
+
# with gr.Row():
|
1984 |
+
# example = gr.Examples(
|
1985 |
+
# label="Quick Example",
|
1986 |
+
# examples=EXAMPLES,
|
1987 |
+
# inputs=[input_image, prompt, seed, init_type, example_type, blending, resize_default, vlm_model_dropdown],
|
1988 |
+
# examples_per_page=10,
|
1989 |
+
# cache_examples=False,
|
1990 |
+
# )
|
1991 |
+
|
1992 |
+
|
1993 |
+
# with gr.Accordion(label="🎬 Feature Details:", open=True, elem_id="accordion"):
|
1994 |
+
# with gr.Row(equal_height=True):
|
1995 |
+
# gr.Markdown(tips)
|
1996 |
+
|
1997 |
+
# with gr.Row():
|
1998 |
+
# gr.Markdown(citation)
|
1999 |
+
|
2000 |
+
# ## gr.examples can not be used to update the gr.Gallery, so we need to use the following two functions to update the gr.Gallery.
|
2001 |
+
# ## And we need to solve the conflict between the upload and change example functions.
|
2002 |
+
# input_image.upload(
|
2003 |
+
# init_img,
|
2004 |
+
# [input_image, init_type, prompt, aspect_ratio, example_change_times],
|
2005 |
+
# [input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, init_type, aspect_ratio, resize_default, invert_mask_state, example_change_times]
|
2006 |
+
|
2007 |
+
# )
|
2008 |
+
# example_type.change(fn=update_example, inputs=[example_type, prompt, example_change_times], outputs=[input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, target_prompt, invert_mask_state, example_change_times])
|
2009 |
+
|
2010 |
+
|
2011 |
+
# ## vlm and base model dropdown
|
2012 |
+
# vlm_model_dropdown.change(fn=update_vlm_model, inputs=[vlm_model_dropdown], outputs=[status])
|
2013 |
+
# base_model_dropdown.change(fn=update_base_model, inputs=[base_model_dropdown], outputs=[status])
|
2014 |
+
|
2015 |
+
# GPT4o_KEY_submit.click(fn=submit_GPT4o_KEY, inputs=[GPT4o_KEY], outputs=[GPT4o_KEY, vlm_model_dropdown])
|
2016 |
+
|
2017 |
+
|
2018 |
+
# ips=[input_image,
|
2019 |
+
# original_image,
|
2020 |
+
# original_mask,
|
2021 |
+
# prompt,
|
2022 |
+
# negative_prompt,
|
2023 |
+
# control_strength,
|
2024 |
+
# seed,
|
2025 |
+
# randomize_seed,
|
2026 |
+
# guidance_scale,
|
2027 |
+
# num_inference_steps,
|
2028 |
+
# num_samples,
|
2029 |
+
# blending,
|
2030 |
+
# category,
|
2031 |
+
# target_prompt,
|
2032 |
+
# resize_default,
|
2033 |
+
# aspect_ratio,
|
2034 |
+
# invert_mask_state]
|
2035 |
+
|
2036 |
+
# ## run brushedit
|
2037 |
+
# run_button.click(fn=process, inputs=ips, outputs=[result_gallery, mask_gallery, masked_gallery, prompt, target_prompt, invert_mask_state])
|
2038 |
+
|
2039 |
+
|
2040 |
+
# ## mask func
|
2041 |
+
# mask_button.click(fn=process_mask, inputs=[input_image, original_image, prompt, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask, category])
|
2042 |
+
# random_mask_button.click(fn=process_random_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
|
2043 |
+
# dilation_mask_button.click(fn=process_dilation_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[ masked_gallery, mask_gallery, original_mask])
|
2044 |
+
# erosion_mask_button.click(fn=process_erosion_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[ masked_gallery, mask_gallery, original_mask])
|
2045 |
+
# invert_mask_button.click(fn=invert_mask, inputs=[input_image, original_image, original_mask], outputs=[masked_gallery, mask_gallery, original_mask, invert_mask_state])
|
2046 |
+
|
2047 |
+
# ## reset func
|
2048 |
+
# reset_button.click(fn=reset_func, inputs=[input_image, original_image, original_mask, prompt, target_prompt], outputs=[input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, resize_default, invert_mask_state])
|
2049 |
+
|
2050 |
+
# input_image.upload(fn=generate_blip_description, inputs=[input_image], outputs=[blip_description, blip_output])
|
2051 |
+
# verify_deepseek.click(fn=verify_deepseek_api, outputs=[deepseek_verified, deepseek_key])
|
2052 |
+
# enhance_button.click(fn=enhance_description, inputs=[blip_output, prompt], outputs=[enhanced_description, enhanced_output])
|
2053 |
+
# decompose_button.click(fn=decompose_description, inputs=[enhanced_output], outputs=[decomposed_description, decomposed_output])
|
2054 |
+
# retrieve_button.click(fn=mix_and_search, inputs=[enhanced_output, result_gallery], outputs=[retrieve_gallery])
|
2055 |
+
|
2056 |
+
# demo.launch(server_name="0.0.0.0", server_port=12345, share=True)
|
2057 |
+
|
2058 |
+
|
brushedit_app_new_aftermeeting_nocirr.py
ADDED
@@ -0,0 +1,1809 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import os, random, sys
import numpy as np
import requests
import torch

from pathlib import Path
import pandas as pd
import concurrent.futures
import faiss
import gradio as gr

from PIL import Image

import torch.nn.functional as F  # newly added
from huggingface_hub import hf_hub_download, snapshot_download
from scipy.ndimage import binary_dilation, binary_erosion
from transformers import (LlavaNextProcessor, LlavaNextForConditionalGeneration,
                          Qwen2VLForConditionalGeneration, Qwen2VLProcessor)

from segment_anything import SamPredictor, build_sam, SamAutomaticMaskGenerator
from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler
from diffusers.image_processor import VaeImageProcessor


from app.src.vlm_pipeline import (
    vlm_response_editing_type,
    vlm_response_object_wait_for_edit,
    vlm_response_mask,
    vlm_response_prompt_after_apply_instruction
)
from app.src.brushedit_all_in_one_pipeline import BrushEdit_Pipeline
from app.utils.utils import load_grounding_dino_model

from app.src.vlm_template import vlms_template
from app.src.base_model_template import base_models_template
from app.src.aspect_ratio_template import aspect_ratios

from openai import OpenAI
base_openai_url = "https://api.deepseek.com/"
base_api_key = "sk-d145b963a92649a88843caeb741e8bbc"


from transformers import BlipProcessor, BlipForConditionalGeneration
from transformers import CLIPProcessor, CLIPModel

from app.deepseek.instructions import (
    create_apply_editing_messages_deepseek,
    create_decomposed_query_messages_deepseek
)
from clip_retrieval.clip_client import ClipClient
#### Description ####
logo = r"""
<center><img src='./assets/logo_brushedit.png' alt='BrushEdit logo' style="width:80px; margin-bottom:10px"></center>
"""
head = r"""
<div style="text-align: center;">
<h1> Zero-Shot Composed Image Retrieval Based on Diffusion Model Priors and Large Language Models</h1>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a href=''><img src='https://img.shields.io/badge/Project_Page-BrushEdit-green' alt='Project Page'></a>
<a href=''><img src='https://img.shields.io/badge/Paper-Arxiv-blue'></a>
<a href=''><img src='https://img.shields.io/badge/Code-Github-orange'></a>

</div>
</br>
</div>
"""
descriptions = r"""
Demo for ZS-CIR"""

instructions = r"""
Demo for ZS-CIR"""

tips = r"""
Demo for ZS-CIR
"""


citation = r"""
Demo for ZS-CIR"""
# - - - - - examples - - - - - #
EXAMPLES = [

    [
        Image.open("./assets/frog/frog.jpeg").convert("RGBA"),
        "add a magic hat on frog head.",
        642087011,
        "frog",
        "frog",
        True,
        False,
        "GPT4-o (Highly Recommended)"
    ],
    [
        Image.open("./assets/chinese_girl/chinese_girl.png").convert("RGBA"),
        "replace the background to ancient China.",
        648464818,
        "chinese_girl",
        "chinese_girl",
        True,
        False,
        "GPT4-o (Highly Recommended)"
    ],
    [
        Image.open("./assets/angel_christmas/angel_christmas.png").convert("RGBA"),
        "remove the deer.",
        648464818,
        "angel_christmas",
        "angel_christmas",
        False,
        False,
        "GPT4-o (Highly Recommended)"
    ],
    [
        Image.open("./assets/sunflower_girl/sunflower_girl.png").convert("RGBA"),
        "add a wreath on head.",
        648464818,
        "sunflower_girl",
        "sunflower_girl",
        True,
        False,
        "GPT4-o (Highly Recommended)"
    ],
    [
        Image.open("./assets/girl_on_sun/girl_on_sun.png").convert("RGBA"),
        "add a butterfly fairy.",
        648464818,
        "girl_on_sun",
        "girl_on_sun",
        True,
        False,
        "GPT4-o (Highly Recommended)"
    ],
    [
        Image.open("./assets/spider_man_rm/spider_man.png").convert("RGBA"),
        "remove the christmas hat.",
        642087011,
        "spider_man_rm",
        "spider_man_rm",
        False,
        False,
        "GPT4-o (Highly Recommended)"
    ],
    [
        Image.open("./assets/anime_flower/anime_flower.png").convert("RGBA"),
        "remove the flower.",
        642087011,
        "anime_flower",
        "anime_flower",
        False,
        False,
        "GPT4-o (Highly Recommended)"
    ],
    [
        Image.open("./assets/chenduling/chengduling.jpg").convert("RGBA"),
        "replace the clothes to a delicated floral skirt.",
        648464818,
        "chenduling",
        "chenduling",
        True,
        False,
        "GPT4-o (Highly Recommended)"
    ],
    [
        Image.open("./assets/hedgehog_rp_bg/hedgehog.png").convert("RGBA"),
        "make the hedgehog in Italy.",
        648464818,
        "hedgehog_rp_bg",
        "hedgehog_rp_bg",
        True,
        False,
        "GPT4-o (Highly Recommended)"
    ],

]
INPUT_IMAGE_PATH = {
    "frog": "./assets/frog/frog.jpeg",
    "chinese_girl": "./assets/chinese_girl/chinese_girl.png",
    "angel_christmas": "./assets/angel_christmas/angel_christmas.png",
    "sunflower_girl": "./assets/sunflower_girl/sunflower_girl.png",
    "girl_on_sun": "./assets/girl_on_sun/girl_on_sun.png",
    "spider_man_rm": "./assets/spider_man_rm/spider_man.png",
    "anime_flower": "./assets/anime_flower/anime_flower.png",
    "chenduling": "./assets/chenduling/chengduling.jpg",
    "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/hedgehog.png",
}
MASK_IMAGE_PATH = {
    "frog": "./assets/frog/mask_f7b350de-6f2c-49e3-b535-995c486d78e7.png",
    "chinese_girl": "./assets/chinese_girl/mask_54759648-0989-48e0-bc82-f20e28b5ec29.png",
    "angel_christmas": "./assets/angel_christmas/mask_f15d9b45-c978-4e3d-9f5f-251e308560c3.png",
    "sunflower_girl": "./assets/sunflower_girl/mask_99cc50b4-7dc4-4de5-8748-ec10772f0317.png",
    "girl_on_sun": "./assets/girl_on_sun/mask_264eac8b-8b65-479c-9755-020a60880c37.png",
    "spider_man_rm": "./assets/spider_man_rm/mask_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png",
    "anime_flower": "./assets/anime_flower/mask_37553172-9b38-4727-bf2e-37d7e2b93461.png",
    "chenduling": "./assets/chenduling/mask_68e3ff6f-da07-4b37-91df-13d6eed7b997.png",
    "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/mask_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png",
}
MASKED_IMAGE_PATH = {
    "frog": "./assets/frog/masked_image_f7b350de-6f2c-49e3-b535-995c486d78e7.png",
    "chinese_girl": "./assets/chinese_girl/masked_image_54759648-0989-48e0-bc82-f20e28b5ec29.png",
    "angel_christmas": "./assets/angel_christmas/masked_image_f15d9b45-c978-4e3d-9f5f-251e308560c3.png",
    "sunflower_girl": "./assets/sunflower_girl/masked_image_99cc50b4-7dc4-4de5-8748-ec10772f0317.png",
    "girl_on_sun": "./assets/girl_on_sun/masked_image_264eac8b-8b65-479c-9755-020a60880c37.png",
    "spider_man_rm": "./assets/spider_man_rm/masked_image_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png",
    "anime_flower": "./assets/anime_flower/masked_image_37553172-9b38-4727-bf2e-37d7e2b93461.png",
    "chenduling": "./assets/chenduling/masked_image_68e3ff6f-da07-4b37-91df-13d6eed7b997.png",
    "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/masked_image_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png",
}
OUTPUT_IMAGE_PATH = {
    "frog": "./assets/frog/image_edit_f7b350de-6f2c-49e3-b535-995c486d78e7_1.png",
    "chinese_girl": "./assets/chinese_girl/image_edit_54759648-0989-48e0-bc82-f20e28b5ec29_1.png",
    "angel_christmas": "./assets/angel_christmas/image_edit_f15d9b45-c978-4e3d-9f5f-251e308560c3_0.png",
    "sunflower_girl": "./assets/sunflower_girl/image_edit_99cc50b4-7dc4-4de5-8748-ec10772f0317_3.png",
    "girl_on_sun": "./assets/girl_on_sun/image_edit_264eac8b-8b65-479c-9755-020a60880c37_0.png",
    "spider_man_rm": "./assets/spider_man_rm/image_edit_a5d410e6-8e8d-432f-8144-defbc3e1eae9_0.png",
    "anime_flower": "./assets/anime_flower/image_edit_37553172-9b38-4727-bf2e-37d7e2b93461_2.png",
    "chenduling": "./assets/chenduling/image_edit_68e3ff6f-da07-4b37-91df-13d6eed7b997_0.png",
    "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/image_edit_db7f8bf8-8349-46d3-b14e-43d67fbe25d3_3.png",
}
# os.environ['GRADIO_TEMP_DIR'] = 'gradio_temp_dir'
# os.makedirs('gradio_temp_dir', exist_ok=True)

VLM_MODEL_NAMES = list(vlms_template.keys())
DEFAULT_VLM_MODEL_NAME = "Qwen2-VL-7B-Instruct (Default)"

BASE_MODELS = list(base_models_template.keys())
DEFAULT_BASE_MODEL = "realisticVision (Default)"

ASPECT_RATIO_LABELS = list(aspect_ratios)
DEFAULT_ASPECT_RATIO = ASPECT_RATIO_LABELS[0]


## init device
try:
    if torch.cuda.is_available():
        device = "cuda"
    elif sys.platform == "darwin" and torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
except:
    device = "cpu"

# ## init torch dtype
# if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
#     torch_dtype = torch.bfloat16
# else:
#     torch_dtype = torch.float16

# if device == "mps":
#     torch_dtype = torch.float16

torch_dtype = torch.float16
# download hf models
BrushEdit_path = "models/"
if not os.path.exists(BrushEdit_path):
    BrushEdit_path = snapshot_download(
        repo_id="TencentARC/BrushEdit",
        local_dir=BrushEdit_path,
        token=os.getenv("HF_TOKEN"),
    )

## init default VLM
vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[DEFAULT_VLM_MODEL_NAME]
if vlm_processor != "" and vlm_model != "":
    vlm_model.to(device)
else:
    raise gr.Error("Please download the default VLM model " + DEFAULT_VLM_MODEL_NAME + " first.")

## init default LLM
llm_model = OpenAI(api_key=base_api_key, base_url=base_openai_url)

## init base model
base_model_path = os.path.join(BrushEdit_path, "base_model/realisticVisionV60B1_v51VAE")
brushnet_path = os.path.join(BrushEdit_path, "brushnetX")
sam_path = os.path.join(BrushEdit_path, "sam/sam_vit_h_4b8939.pth")
groundingdino_path = os.path.join(BrushEdit_path, "grounding_dino/groundingdino_swint_ogc.pth")


# input brushnetX ckpt path
brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch_dtype)
pipe = StableDiffusionBrushNetPipeline.from_pretrained(
    base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
)
# speed up the diffusion process with a faster scheduler and memory optimization
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
# remove the following line if xformers is not installed or when using Torch 2.0.
# pipe.enable_xformers_memory_efficient_attention()
pipe.enable_model_cpu_offload()


## init SAM
sam = build_sam(checkpoint=sam_path)
sam.to(device=device)
sam_predictor = SamPredictor(sam)
sam_automask_generator = SamAutomaticMaskGenerator(sam)

## init groundingdino_model
config_file = 'app/utils/GroundingDINO_SwinT_OGC.py'
groundingdino_model = load_grounding_dino_model(config_file, groundingdino_path, device=device)
+
## Ordinary function
|
314 |
+
def crop_and_resize(image: Image.Image,
|
315 |
+
target_width: int,
|
316 |
+
target_height: int) -> Image.Image:
|
317 |
+
"""
|
318 |
+
Crops and resizes an image while preserving the aspect ratio.
|
319 |
+
|
320 |
+
Args:
|
321 |
+
image (Image.Image): Input PIL image to be cropped and resized.
|
322 |
+
target_width (int): Target width of the output image.
|
323 |
+
target_height (int): Target height of the output image.
|
324 |
+
|
325 |
+
Returns:
|
326 |
+
Image.Image: Cropped and resized image.
|
327 |
+
"""
|
328 |
+
# Original dimensions
|
329 |
+
original_width, original_height = image.size
|
330 |
+
original_aspect = original_width / original_height
|
331 |
+
target_aspect = target_width / target_height
|
332 |
+
|
333 |
+
# Calculate crop box to maintain aspect ratio
|
334 |
+
if original_aspect > target_aspect:
|
335 |
+
# Crop horizontally
|
336 |
+
new_width = int(original_height * target_aspect)
|
337 |
+
new_height = original_height
|
338 |
+
left = (original_width - new_width) / 2
|
339 |
+
top = 0
|
340 |
+
right = left + new_width
|
341 |
+
bottom = original_height
|
342 |
+
else:
|
343 |
+
# Crop vertically
|
344 |
+
new_width = original_width
|
345 |
+
new_height = int(original_width / target_aspect)
|
346 |
+
left = 0
|
347 |
+
top = (original_height - new_height) / 2
|
348 |
+
right = original_width
|
349 |
+
bottom = top + new_height
|
350 |
+
|
351 |
+
# Crop and resize
|
352 |
+
cropped_image = image.crop((left, top, right, bottom))
|
353 |
+
resized_image = cropped_image.resize((target_width, target_height), Image.NEAREST)
|
354 |
+
return resized_image
|
355 |
+
|
356 |
+
|
357 |
+
## Ordinary function
|
358 |
+
def resize(image: Image.Image,
|
359 |
+
target_width: int,
|
360 |
+
target_height: int) -> Image.Image:
|
361 |
+
"""
|
362 |
+
Crops and resizes an image while preserving the aspect ratio.
|
363 |
+
|
364 |
+
Args:
|
365 |
+
image (Image.Image): Input PIL image to be cropped and resized.
|
366 |
+
target_width (int): Target width of the output image.
|
367 |
+
target_height (int): Target height of the output image.
|
368 |
+
|
369 |
+
Returns:
|
370 |
+
Image.Image: Cropped and resized image.
|
371 |
+
"""
|
372 |
+
# Original dimensions
|
373 |
+
resized_image = image.resize((target_width, target_height), Image.NEAREST)
|
374 |
+
return resized_image
|
375 |
+
|
376 |
+
|
377 |
+
def move_mask_func(mask, direction, units):
|
378 |
+
binary_mask = mask.squeeze()>0
|
379 |
+
rows, cols = binary_mask.shape
|
380 |
+
moved_mask = np.zeros_like(binary_mask, dtype=bool)
|
381 |
+
|
382 |
+
if direction == 'down':
|
383 |
+
# move down
|
384 |
+
moved_mask[max(0, units):, :] = binary_mask[:rows - units, :]
|
385 |
+
|
386 |
+
elif direction == 'up':
|
387 |
+
# move up
|
388 |
+
moved_mask[:rows - units, :] = binary_mask[units:, :]
|
389 |
+
|
390 |
+
elif direction == 'right':
|
391 |
+
# move left
|
392 |
+
moved_mask[:, max(0, units):] = binary_mask[:, :cols - units]
|
393 |
+
|
394 |
+
elif direction == 'left':
|
395 |
+
# move right
|
396 |
+
moved_mask[:, :cols - units] = binary_mask[:, units:]
|
397 |
+
|
398 |
+
return moved_mask
|
399 |
+
|
400 |
+
|
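# Small worked example for move_mask_func (illustrative only): a single foreground
# pixel at (2, 2) ends up at (2, 3) after moving one pixel to the right.
#   >>> m = np.zeros((5, 5), dtype=np.uint8); m[2, 2] = 255
#   >>> move_mask_func(m, 'right', 1)[2, 3]   # -> True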
def random_mask_func(mask, dilation_type='square', dilation_size=20):
    # Derive a coarser region from the given mask (dilation, erosion, bounding box or ellipse).
    binary_mask = mask.squeeze() > 0

    if dilation_type == 'square_dilation':
        structure = np.ones((dilation_size, dilation_size), dtype=bool)
        dilated_mask = binary_dilation(binary_mask, structure=structure)
    elif dilation_type == 'square_erosion':
        structure = np.ones((dilation_size, dilation_size), dtype=bool)
        dilated_mask = binary_erosion(binary_mask, structure=structure)
    elif dilation_type == 'bounding_box':
        # find the extreme foreground points
        rows, cols = np.where(binary_mask)
        if len(rows) == 0 or len(cols) == 0:
            return mask  # return original mask if no valid points

        min_row = np.min(rows)
        max_row = np.max(rows)
        min_col = np.min(cols)
        max_col = np.max(cols)

        # create a bounding box
        dilated_mask = np.zeros_like(binary_mask, dtype=bool)
        dilated_mask[min_row:max_row + 1, min_col:max_col + 1] = True

    elif dilation_type == 'bounding_ellipse':
        # find the extreme foreground points
        rows, cols = np.where(binary_mask)
        if len(rows) == 0 or len(cols) == 0:
            return mask  # return original mask if no valid points

        min_row = np.min(rows)
        max_row = np.max(rows)
        min_col = np.min(cols)
        max_col = np.max(cols)

        # calculate the center and axis lengths of the ellipse
        center = ((min_col + max_col) // 2, (min_row + max_row) // 2)
        a = (max_col - min_col) // 2  # semi-axis along x
        b = (max_row - min_row) // 2  # semi-axis along y

        # create a bounding ellipse
        y, x = np.ogrid[:mask.shape[0], :mask.shape[1]]
        ellipse_mask = ((x - center[0])**2 / a**2 + (y - center[1])**2 / b**2) <= 1
        dilated_mask = np.zeros_like(binary_mask, dtype=bool)
        dilated_mask[ellipse_mask] = True
    else:
        raise ValueError("dilation_type must be one of 'square_dilation', 'square_erosion', "
                         "'bounding_box' or 'bounding_ellipse'")

    # return an h,w,1 uint8 mask in [0, 255]
    dilated_mask = np.uint8(dilated_mask[:, :, np.newaxis]) * 255
    return dilated_mask
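# Small worked example for random_mask_func (illustrative only): a 3x3 foreground
# patch grows into its axis-aligned bounding box, returned as an h,w,1 uint8 mask.
#   >>> m = np.zeros((8, 8, 1), dtype=np.uint8); m[2:5, 3:6, 0] = 255
#   >>> random_mask_func(m, 'bounding_box').shape   # -> (8, 8, 1)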
## Gradio component function
def update_vlm_model(vlm_name):
    global vlm_model, vlm_processor
    if vlm_model is not None:
        del vlm_model
        torch.cuda.empty_cache()

    vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[vlm_name]

    ## We recommend using preloaded models; otherwise downloading can take a long time. The model list can be edited in vlm_template.py.
    if vlm_type == "llava-next":
        if vlm_processor != "" and vlm_model != "":
            vlm_model.to(device)
            return vlm_model_dropdown
        else:
            if os.path.exists(vlm_local_path):
                vlm_processor = LlavaNextProcessor.from_pretrained(vlm_local_path)
                vlm_model = LlavaNextForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto")
            else:
                if vlm_name == "llava-v1.6-mistral-7b-hf (Preload)":
                    vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
                    vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", torch_dtype="auto", device_map="auto")
                elif vlm_name == "llama3-llava-next-8b-hf (Preload)":
                    vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llama3-llava-next-8b-hf")
                    vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llama3-llava-next-8b-hf", torch_dtype="auto", device_map="auto")
                elif vlm_name == "llava-v1.6-vicuna-13b-hf (Preload)":
                    vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf")
                    vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf", torch_dtype="auto", device_map="auto")
                elif vlm_name == "llava-v1.6-34b-hf (Preload)":
                    vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-34b-hf")
                    vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-34b-hf", torch_dtype="auto", device_map="auto")
                elif vlm_name == "llava-next-72b-hf (Preload)":
                    vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-next-72b-hf")
                    vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-next-72b-hf", torch_dtype="auto", device_map="auto")
    elif vlm_type == "qwen2-vl":
        if vlm_processor != "" and vlm_model != "":
            vlm_model.to(device)
            return vlm_model_dropdown
        else:
            if os.path.exists(vlm_local_path):
                vlm_processor = Qwen2VLProcessor.from_pretrained(vlm_local_path)
                vlm_model = Qwen2VLForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto")
            else:
                if vlm_name == "qwen2-vl-2b-instruct (Preload)":
                    vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
                    vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", torch_dtype="auto", device_map="auto")
                elif vlm_name == "qwen2-vl-7b-instruct (Preload)":
                    vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
                    vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto")
                elif vlm_name == "qwen2-vl-72b-instruct (Preload)":
                    vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-72B-Instruct")
                    vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-72B-Instruct", torch_dtype="auto", device_map="auto")
    elif vlm_type == "openai":
        pass
    return "success"
def update_base_model(base_model_name):
    global pipe
    ## We recommend using preloaded models; otherwise downloading can take a long time. The model list can be edited in base_model_template.py.
    if pipe is not None:
        del pipe
        torch.cuda.empty_cache()
    base_model_path, pipe = base_models_template[base_model_name]
    if pipe != "":
        pipe.to(device)
    else:
        if os.path.exists(base_model_path):
            pipe = StableDiffusionBrushNetPipeline.from_pretrained(
                base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
            )
            # pipe.enable_xformers_memory_efficient_attention()
            pipe.enable_model_cpu_offload()
        else:
            raise gr.Error(f"The base model {base_model_name} does not exist")
    return "success"
def resize_by_aspect_ratio(original_image, input_mask, original_mask, resize_default, aspect_ratio_name):
    """Shared output-size handling used by the mask-editing callbacks below.

    Looks up the requested aspect ratio (falling back to the input image size for
    "Custom resolution") and, when resize_default is set, rescales the image and both
    masks so that the short side becomes 640 px. Returns the (possibly resized) arrays.
    """
    output_w, output_h = aspect_ratios[aspect_ratio_name]
    is_custom_resolution = (output_w == "" or output_h == "")
    if is_custom_resolution:
        output_h, output_w = original_image.shape[:2]
    if resize_default:
        short_side = min(output_w, output_h)
        scale_ratio = 640 / short_side
        output_w = int(output_w * scale_ratio)
        output_h = int(output_h * scale_ratio)
        original_image = np.array(resize(Image.fromarray(original_image),
                                         target_width=int(output_w), target_height=int(output_h)))
        if input_mask is not None:
            input_mask = np.array(resize(Image.fromarray(np.squeeze(input_mask)),
                                         target_width=int(output_w), target_height=int(output_h)))
        if original_mask is not None:
            original_mask = np.array(resize(Image.fromarray(np.squeeze(original_mask)),
                                            target_width=int(output_w), target_height=int(output_h)))
        gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
    elif is_custom_resolution:
        gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
    return original_image, input_mask, original_mask


def process_random_mask(input_image,
                        original_image,
                        original_mask,
                        resize_default,
                        aspect_ratio_name,
                        ):
    alpha_mask = input_image["layers"][0].split()[3]
    input_mask = np.asarray(alpha_mask)

    original_image, input_mask, original_mask = resize_by_aspect_ratio(
        original_image, input_mask, original_mask, resize_default, aspect_ratio_name)

    # prefer a freshly drawn mask over the stored one
    if input_mask.max() > 0:
        original_mask = input_mask
    if original_mask is None:
        raise gr.Error('Please generate mask first')
    if original_mask.ndim == 2:
        original_mask = original_mask[:, :, None]

    dilation_type = np.random.choice(['bounding_box', 'bounding_ellipse'])
    random_mask = random_mask_func(original_mask, dilation_type).squeeze()

    mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")

    masked_image = original_image * (1 - (random_mask[:, :, None] > 0))
    masked_image = masked_image.astype(original_image.dtype)
    masked_image = Image.fromarray(masked_image)

    return [masked_image], [mask_image], random_mask[:, :, None].astype(np.uint8)
+
def process_dilation_mask(input_image,
|
605 |
+
original_image,
|
606 |
+
original_mask,
|
607 |
+
resize_default,
|
608 |
+
aspect_ratio_name,
|
609 |
+
dilation_size=20):
|
610 |
+
|
611 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
612 |
+
input_mask = np.asarray(alpha_mask)
|
613 |
+
|
614 |
+
output_w, output_h = aspect_ratios[aspect_ratio_name]
|
615 |
+
if output_w == "" or output_h == "":
|
616 |
+
output_h, output_w = original_image.shape[:2]
|
617 |
+
if resize_default:
|
618 |
+
short_side = min(output_w, output_h)
|
619 |
+
scale_ratio = 640 / short_side
|
620 |
+
output_w = int(output_w * scale_ratio)
|
621 |
+
output_h = int(output_h * scale_ratio)
|
622 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
623 |
+
original_image = np.array(original_image)
|
624 |
+
if input_mask is not None:
|
625 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
626 |
+
input_mask = np.array(input_mask)
|
627 |
+
if original_mask is not None:
|
628 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
629 |
+
original_mask = np.array(original_mask)
|
630 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
631 |
+
else:
|
632 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
633 |
+
pass
|
634 |
+
else:
|
635 |
+
if resize_default:
|
636 |
+
short_side = min(output_w, output_h)
|
637 |
+
scale_ratio = 640 / short_side
|
638 |
+
output_w = int(output_w * scale_ratio)
|
639 |
+
output_h = int(output_h * scale_ratio)
|
640 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
641 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
642 |
+
original_image = np.array(original_image)
|
643 |
+
if input_mask is not None:
|
644 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
645 |
+
input_mask = np.array(input_mask)
|
646 |
+
if original_mask is not None:
|
647 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
648 |
+
original_mask = np.array(original_mask)
|
649 |
+
|
650 |
+
if input_mask.max() == 0:
|
651 |
+
original_mask = original_mask
|
652 |
+
else:
|
653 |
+
original_mask = input_mask
|
654 |
+
|
655 |
+
if original_mask is None:
|
656 |
+
raise gr.Error('Please generate mask first')
|
657 |
+
|
658 |
+
if original_mask.ndim == 2:
|
659 |
+
original_mask = original_mask[:,:,None]
|
660 |
+
|
661 |
+
dilation_type = np.random.choice(['square_dilation'])
|
662 |
+
random_mask = random_mask_func(original_mask, dilation_type, dilation_size).squeeze()
|
663 |
+
|
664 |
+
mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
|
665 |
+
|
666 |
+
masked_image = original_image * (1 - (random_mask[:,:,None]>0))
|
667 |
+
masked_image = masked_image.astype(original_image.dtype)
|
668 |
+
masked_image = Image.fromarray(masked_image)
|
669 |
+
|
670 |
+
return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
|
671 |
+
|
672 |
+
|
def process_erosion_mask(input_image,
                         original_image,
                         original_mask,
                         resize_default,
                         aspect_ratio_name,
                         dilation_size=20):
    alpha_mask = input_image["layers"][0].split()[3]
    input_mask = np.asarray(alpha_mask)

    original_image, input_mask, original_mask = resize_by_aspect_ratio(
        original_image, input_mask, original_mask, resize_default, aspect_ratio_name)

    if input_mask.max() > 0:
        original_mask = input_mask
    if original_mask is None:
        raise gr.Error('Please generate mask first')
    if original_mask.ndim == 2:
        original_mask = original_mask[:, :, None]

    dilation_type = np.random.choice(['square_erosion'])
    random_mask = random_mask_func(original_mask, dilation_type, dilation_size).squeeze()

    mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")

    masked_image = original_image * (1 - (random_mask[:, :, None] > 0))
    masked_image = masked_image.astype(original_image.dtype)
    masked_image = Image.fromarray(masked_image)

    return [masked_image], [mask_image], random_mask[:, :, None].astype(np.uint8)
def move_mask_left(input_image,
                   original_image,
                   original_mask,
                   moving_pixels,
                   resize_default,
                   aspect_ratio_name):
    alpha_mask = input_image["layers"][0].split()[3]
    input_mask = np.asarray(alpha_mask)

    original_image, input_mask, original_mask = resize_by_aspect_ratio(
        original_image, input_mask, original_mask, resize_default, aspect_ratio_name)

    if input_mask.max() > 0:
        original_mask = input_mask
    if original_mask is None:
        raise gr.Error('Please generate mask first')
    if original_mask.ndim == 2:
        original_mask = original_mask[:, :, None]

    moved_mask = move_mask_func(original_mask, 'left', int(moving_pixels)).squeeze()
    mask_image = Image.fromarray(((moved_mask > 0).astype(np.uint8) * 255)).convert("RGB")

    masked_image = original_image * (1 - (moved_mask[:, :, None] > 0))
    masked_image = masked_image.astype(original_image.dtype)
    masked_image = Image.fromarray(masked_image)

    if moved_mask.max() <= 1:
        moved_mask = ((moved_mask * 255)[:, :, None]).astype(np.uint8)
    original_mask = moved_mask
    return [masked_image], [mask_image], original_mask.astype(np.uint8)
def move_mask_right(input_image,
                    original_image,
                    original_mask,
                    moving_pixels,
                    resize_default,
                    aspect_ratio_name):
    alpha_mask = input_image["layers"][0].split()[3]
    input_mask = np.asarray(alpha_mask)

    original_image, input_mask, original_mask = resize_by_aspect_ratio(
        original_image, input_mask, original_mask, resize_default, aspect_ratio_name)

    if input_mask.max() > 0:
        original_mask = input_mask
    if original_mask is None:
        raise gr.Error('Please generate mask first')
    if original_mask.ndim == 2:
        original_mask = original_mask[:, :, None]

    moved_mask = move_mask_func(original_mask, 'right', int(moving_pixels)).squeeze()

    mask_image = Image.fromarray(((moved_mask > 0).astype(np.uint8) * 255)).convert("RGB")

    masked_image = original_image * (1 - (moved_mask[:, :, None] > 0))
    masked_image = masked_image.astype(original_image.dtype)
    masked_image = Image.fromarray(masked_image)

    if moved_mask.max() <= 1:
        moved_mask = ((moved_mask * 255)[:, :, None]).astype(np.uint8)
    original_mask = moved_mask

    return [masked_image], [mask_image], original_mask.astype(np.uint8)
def move_mask_up(input_image,
                 original_image,
                 original_mask,
                 moving_pixels,
                 resize_default,
                 aspect_ratio_name):
    alpha_mask = input_image["layers"][0].split()[3]
    input_mask = np.asarray(alpha_mask)

    original_image, input_mask, original_mask = resize_by_aspect_ratio(
        original_image, input_mask, original_mask, resize_default, aspect_ratio_name)

    if input_mask.max() > 0:
        original_mask = input_mask
    if original_mask is None:
        raise gr.Error('Please generate mask first')
    if original_mask.ndim == 2:
        original_mask = original_mask[:, :, None]

    moved_mask = move_mask_func(original_mask, 'up', int(moving_pixels)).squeeze()
    mask_image = Image.fromarray(((moved_mask > 0).astype(np.uint8) * 255)).convert("RGB")

    masked_image = original_image * (1 - (moved_mask[:, :, None] > 0))
    masked_image = masked_image.astype(original_image.dtype)
    masked_image = Image.fromarray(masked_image)

    if moved_mask.max() <= 1:
        moved_mask = ((moved_mask * 255)[:, :, None]).astype(np.uint8)
    original_mask = moved_mask

    return [masked_image], [mask_image], original_mask.astype(np.uint8)
def move_mask_down(input_image,
                   original_image,
                   original_mask,
                   moving_pixels,
                   resize_default,
                   aspect_ratio_name):
    alpha_mask = input_image["layers"][0].split()[3]
    input_mask = np.asarray(alpha_mask)

    original_image, input_mask, original_mask = resize_by_aspect_ratio(
        original_image, input_mask, original_mask, resize_default, aspect_ratio_name)

    if input_mask.max() > 0:
        original_mask = input_mask
    if original_mask is None:
        raise gr.Error('Please generate mask first')
    if original_mask.ndim == 2:
        original_mask = original_mask[:, :, None]

    moved_mask = move_mask_func(original_mask, 'down', int(moving_pixels)).squeeze()
    mask_image = Image.fromarray(((moved_mask > 0).astype(np.uint8) * 255)).convert("RGB")

    masked_image = original_image * (1 - (moved_mask[:, :, None] > 0))
    masked_image = masked_image.astype(original_image.dtype)
    masked_image = Image.fromarray(masked_image)

    if moved_mask.max() <= 1:
        moved_mask = ((moved_mask * 255)[:, :, None]).astype(np.uint8)
    original_mask = moved_mask

    return [masked_image], [mask_image], original_mask.astype(np.uint8)
def invert_mask(input_image,
                original_image,
                original_mask,
                ):
    alpha_mask = input_image["layers"][0].split()[3]
    input_mask = np.asarray(alpha_mask)
    if input_mask.max() == 0:
        if original_mask is None:
            raise gr.Error('Please generate mask first')
        original_mask = 1 - (original_mask > 0).astype(np.uint8)
    else:
        original_mask = 1 - (input_mask > 0).astype(np.uint8)

    original_mask = original_mask.squeeze()
    mask_image = Image.fromarray(original_mask * 255).convert("RGB")

    if original_mask.ndim == 2:
        original_mask = original_mask[:, :, None]

    if original_mask.max() <= 1:
        original_mask = (original_mask * 255).astype(np.uint8)

    masked_image = original_image * (1 - (original_mask > 0))
    masked_image = masked_image.astype(original_image.dtype)
    masked_image = Image.fromarray(masked_image)

    return [masked_image], [mask_image], original_mask, True
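# Mask inversion is simple boolean arithmetic (illustrative sketch): foreground and
# background swap roles.
#   >>> m = np.array([[0, 255], [255, 0]], dtype=np.uint8)
#   >>> 1 - (m > 0).astype(np.uint8)   # -> [[1, 0], [0, 1]]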
def init_img(base,
             init_type,
             prompt,
             aspect_ratio,
             example_change_times
             ):
    image_pil = base["background"].convert("RGB")
    original_image = np.array(image_pil)
    if max(original_image.shape[0], original_image.shape[1]) * 1.0 / min(original_image.shape[0], original_image.shape[1]) > 2.0:
        raise gr.Error('image aspect ratio cannot be larger than 2.0')
    if init_type in MASK_IMAGE_PATH.keys() and example_change_times < 2:
        mask_gallery = [Image.open(MASK_IMAGE_PATH[init_type]).convert("L")]
        masked_gallery = [Image.open(MASKED_IMAGE_PATH[init_type]).convert("RGB")]
        result_gallery = [Image.open(OUTPUT_IMAGE_PATH[init_type]).convert("RGB")]
        width, height = image_pil.size
        image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
        height_new, width_new = image_processor.get_default_height_width(image_pil, height, width)
        image_pil = image_pil.resize((width_new, height_new))
        mask_gallery[0] = mask_gallery[0].resize((width_new, height_new))
        masked_gallery[0] = masked_gallery[0].resize((width_new, height_new))
        result_gallery[0] = result_gallery[0].resize((width_new, height_new))
        original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:, :, None]  # h,w,1
        return base, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, "", "", "Custom resolution", False, False, example_change_times
    else:
        if aspect_ratio not in ASPECT_RATIO_LABELS:
            aspect_ratio = "Custom resolution"
        return base, original_image, None, "", None, None, None, "", "", aspect_ratio, True, False, 0
def reset_func(input_image,
               original_image,
               original_mask,
               prompt,
               target_prompt,
               ):
    input_image = None
    original_image = None
    original_mask = None
    prompt = ''
    mask_gallery = []
    masked_gallery = []
    result_gallery = []
    target_prompt = ''
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, True, False
def update_example(example_type,
                   prompt,
                   example_change_times):
    input_image = INPUT_IMAGE_PATH[example_type]
    image_pil = Image.open(input_image).convert("RGB")
    mask_gallery = [Image.open(MASK_IMAGE_PATH[example_type]).convert("L")]
    masked_gallery = [Image.open(MASKED_IMAGE_PATH[example_type]).convert("RGB")]
    result_gallery = [Image.open(OUTPUT_IMAGE_PATH[example_type]).convert("RGB")]
    width, height = image_pil.size
    image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
    height_new, width_new = image_processor.get_default_height_width(image_pil, height, width)
    image_pil = image_pil.resize((width_new, height_new))
    mask_gallery[0] = mask_gallery[0].resize((width_new, height_new))
    masked_gallery[0] = masked_gallery[0].resize((width_new, height_new))
    result_gallery[0] = result_gallery[0].resize((width_new, height_new))

    original_image = np.array(image_pil)
    original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:, :, None]  # h,w,1
    aspect_ratio = "Custom resolution"
    example_change_times += 1
    return input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, "", False, example_change_times
def generate_target_prompt(input_image,
                           original_image,
                           prompt):
    # load example image
    if isinstance(original_image, str):
        original_image = input_image

    prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(
        vlm_processor,
        vlm_model,
        original_image,
        prompt,
        device)
    return prompt_after_apply_instruction
def process_mask(input_image,
                 original_image,
                 prompt,
                 resize_default,
                 aspect_ratio_name):
    if original_image is None:
        raise gr.Error('Please upload the input image')
    if prompt is None:
        raise gr.Error("Please input your instructions, e.g., remove the xxx")

    ## load mask
    alpha_mask = input_image["layers"][0].split()[3]
    input_mask = np.array(alpha_mask)

    # load example image
    if isinstance(original_image, str):
        original_image = input_image["background"]

    if input_mask.max() == 0:
        # no user-drawn mask: let the VLM infer the editing type, target object, and mask
        category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device)

        object_wait_for_edit = vlm_response_object_wait_for_edit(vlm_processor,
                                                                 vlm_model,
                                                                 original_image,
                                                                 category,
                                                                 prompt,
                                                                 device)
        # original mask: h,w,1 [0, 255]
        original_mask = vlm_response_mask(
            vlm_processor,
            vlm_model,
            category,
            original_image,
            prompt,
            object_wait_for_edit,
            sam,
            sam_predictor,
            sam_automask_generator,
            groundingdino_model,
            device).astype(np.uint8)
    else:
        original_mask = input_mask.astype(np.uint8)
        category = None

    ## resize mask if needed
    original_image, input_mask, original_mask = resize_by_aspect_ratio(
        original_image, input_mask, original_mask, resize_default, aspect_ratio_name)

    if original_mask.ndim == 2:
        original_mask = original_mask[:, :, None]

    mask_image = Image.fromarray(original_mask.squeeze().astype(np.uint8)).convert("RGB")

    masked_image = original_image * (1 - (original_mask > 0))
    masked_image = masked_image.astype(np.uint8)
    masked_image = Image.fromarray(masked_image)

    return [masked_image], [mask_image], original_mask.astype(np.uint8), category
1236 |
+
def process(input_image,
|
1237 |
+
original_image,
|
1238 |
+
original_mask,
|
1239 |
+
prompt,
|
1240 |
+
negative_prompt,
|
1241 |
+
control_strength,
|
1242 |
+
seed,
|
1243 |
+
randomize_seed,
|
1244 |
+
guidance_scale,
|
1245 |
+
num_inference_steps,
|
1246 |
+
num_samples,
|
1247 |
+
blending,
|
1248 |
+
category,
|
1249 |
+
target_prompt,
|
1250 |
+
resize_default,
|
1251 |
+
aspect_ratio_name,
|
1252 |
+
invert_mask_state):
|
1253 |
+
if original_image is None:
|
1254 |
+
if input_image is None:
|
1255 |
+
raise gr.Error('Please upload the input image')
|
1256 |
+
else:
|
1257 |
+
image_pil = input_image["background"].convert("RGB")
|
1258 |
+
original_image = np.array(image_pil)
|
1259 |
+
if prompt is None or prompt == "":
|
1260 |
+
if target_prompt is None or target_prompt == "":
|
1261 |
+
raise gr.Error("Please input your instructions, e.g., remove the xxx")
|
1262 |
+
|
1263 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
1264 |
+
input_mask = np.asarray(alpha_mask)
|
1265 |
+
output_w, output_h = aspect_ratios[aspect_ratio_name]
|
1266 |
+
if output_w == "" or output_h == "":
|
1267 |
+
output_h, output_w = original_image.shape[:2]
|
1268 |
+
|
1269 |
+
if resize_default:
|
1270 |
+
short_side = min(output_w, output_h)
|
1271 |
+
scale_ratio = 640 / short_side
|
1272 |
+
output_w = int(output_w * scale_ratio)
|
1273 |
+
output_h = int(output_h * scale_ratio)
|
1274 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
1275 |
+
original_image = np.array(original_image)
|
1276 |
+
if input_mask is not None:
|
1277 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
1278 |
+
input_mask = np.array(input_mask)
|
1279 |
+
if original_mask is not None:
|
1280 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
1281 |
+
original_mask = np.array(original_mask)
|
1282 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1283 |
+
else:
|
1284 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1285 |
+
pass
|
1286 |
+
else:
|
1287 |
+
if resize_default:
|
1288 |
+
short_side = min(output_w, output_h)
|
1289 |
+
scale_ratio = 640 / short_side
|
1290 |
+
output_w = int(output_w * scale_ratio)
|
1291 |
+
output_h = int(output_h * scale_ratio)
|
1292 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1293 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
1294 |
+
original_image = np.array(original_image)
|
1295 |
+
if input_mask is not None:
|
1296 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
1297 |
+
input_mask = np.array(input_mask)
|
1298 |
+
if original_mask is not None:
|
1299 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
1300 |
+
original_mask = np.array(original_mask)
|
1301 |
+
|
1302 |
+
if invert_mask_state:
|
1303 |
+
original_mask = original_mask
|
1304 |
+
else:
|
1305 |
+
if input_mask.max() == 0:
|
1306 |
+
original_mask = original_mask
|
1307 |
+
else:
|
1308 |
+
original_mask = input_mask
|
1309 |
+
|
1310 |
+
|
1311 |
+
# inpainting directly if target_prompt is not None
|
1312 |
+
if category is not None:
|
1313 |
+
pass
|
1314 |
+
elif target_prompt is not None and len(target_prompt) >= 1 and original_mask is not None:
|
1315 |
+
pass
|
1316 |
+
else:
|
1317 |
+
try:
|
1318 |
+
category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device)
|
1319 |
+
except Exception as e:
|
1320 |
+
raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
|
1321 |
+
|
1322 |
+
|
1323 |
+
if original_mask is not None:
|
1324 |
+
original_mask = np.clip(original_mask, 0, 255).astype(np.uint8)
|
1325 |
+
else:
|
1326 |
+
try:
|
1327 |
+
object_wait_for_edit = vlm_response_object_wait_for_edit(
|
1328 |
+
vlm_processor,
|
1329 |
+
vlm_model,
|
1330 |
+
original_image,
|
1331 |
+
category,
|
1332 |
+
prompt,
|
1333 |
+
device)
|
1334 |
+
|
1335 |
+
original_mask = vlm_response_mask(vlm_processor,
|
1336 |
+
vlm_model,
|
1337 |
+
category,
|
1338 |
+
original_image,
|
1339 |
+
prompt,
|
1340 |
+
object_wait_for_edit,
|
1341 |
+
sam,
|
1342 |
+
sam_predictor,
|
1343 |
+
sam_automask_generator,
|
1344 |
+
groundingdino_model,
|
1345 |
+
device).astype(np.uint8)
|
1346 |
+
except Exception as e:
|
1347 |
+
raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
|
1348 |
+
|
1349 |
+
if original_mask.ndim == 2:
|
1350 |
+
original_mask = original_mask[:,:,None]
|
1351 |
+
|
1352 |
+
|
1353 |
+
if target_prompt is not None and len(target_prompt) >= 1:
|
1354 |
+
prompt_after_apply_instruction = target_prompt
|
1355 |
+
|
1356 |
+
else:
|
1357 |
+
try:
|
1358 |
+
prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(
|
1359 |
+
vlm_processor,
|
1360 |
+
vlm_model,
|
1361 |
+
original_image,
|
1362 |
+
prompt,
|
1363 |
+
device)
|
1364 |
+
except Exception as e:
|
1365 |
+
raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
|
1366 |
+
|
1367 |
+
generator = torch.Generator(device).manual_seed(random.randint(0, 2147483647) if randomize_seed else seed)
|
1368 |
+
|
1369 |
+
|
1370 |
+
with torch.autocast(device):
|
1371 |
+
image, mask_image, mask_np, init_image_np = BrushEdit_Pipeline(pipe,
|
1372 |
+
prompt_after_apply_instruction,
|
1373 |
+
original_mask,
|
1374 |
+
original_image,
|
1375 |
+
generator,
|
1376 |
+
num_inference_steps,
|
1377 |
+
guidance_scale,
|
1378 |
+
control_strength,
|
1379 |
+
negative_prompt,
|
1380 |
+
num_samples,
|
1381 |
+
blending)
|
1382 |
+
original_image = np.array(init_image_np)
|
1383 |
+
masked_image = original_image * (1 - (mask_np>0))
|
1384 |
+
masked_image = masked_image.astype(np.uint8)
|
1385 |
+
masked_image = Image.fromarray(masked_image)
|
1386 |
+
# Save the images (optional)
|
1387 |
+
# import uuid
|
1388 |
+
# uuid = str(uuid.uuid4())
|
1389 |
+
# image[0].save(f"outputs/image_edit_{uuid}_0.png")
|
1390 |
+
# image[1].save(f"outputs/image_edit_{uuid}_1.png")
|
1391 |
+
# image[2].save(f"outputs/image_edit_{uuid}_2.png")
|
1392 |
+
# image[3].save(f"outputs/image_edit_{uuid}_3.png")
|
1393 |
+
# mask_image.save(f"outputs/mask_{uuid}.png")
|
1394 |
+
# masked_image.save(f"outputs/masked_image_{uuid}.png")
|
1395 |
+
# gr.Info(f"Target Prompt: {prompt_after_apply_instruction}", duration=20)
|
1396 |
+
return image, [mask_image], [masked_image], prompt, '', False
|
1397 |
+
|
1398 |
+
|
1399 |
+
# Newly added event handler functions
def generate_blip_description(input_image):
    if input_image is None:
        return "", "Input image cannot be None"
    try:
        image_pil = input_image["background"].convert("RGB")
    except KeyError:
        return "", "Input image missing 'background' key"
    except AttributeError as e:
        return "", f"Invalid image object: {str(e)}"
    try:
        description = generate_caption(blip_processor, blip_model, image_pil, device)
        return description, description  # update both the state and the display component
    except Exception as e:
        return "", f"Caption generation failed: {str(e)}"

# Note: this import sits below its use above; it is executed at module load,
# before the UI ever calls generate_blip_description.
from app.utils.utils import generate_caption
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large", torch_dtype=torch.float16).to(device)
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32", torch_dtype=torch.float16).to(device)

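# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original app): generate_caption is
# imported from app.utils.utils and its body is not shown in this file. A
# minimal BLIP captioning helper along these lines would be compatible with
# the processor/model objects created above; the function name and the
# max_new_tokens value are assumptions, not the project's actual code.
# ---------------------------------------------------------------------------
def _example_blip_caption(processor, model, image_pil, device, max_new_tokens=40):
    """Minimal sketch: caption a PIL image with the fp16 BLIP model loaded above."""
    inputs = processor(images=image_pil, return_tensors="pt").to(device, torch.float16)
    output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    return processor.decode(output_ids[0], skip_special_tokens=True)
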
def submit_GPT4o_KEY(GPT4o_KEY):
    global vlm_model, vlm_processor
    if vlm_model is not None:
        del vlm_model
        torch.cuda.empty_cache()
    try:
        vlm_model = OpenAI(api_key=GPT4o_KEY, base_url="https://api.deepseek.com")
        vlm_processor = ""
        response = vlm_model.chat.completions.create(
            model="deepseek-chat",
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Hello."}
            ]
        )
        response_str = response.choices[0].message.content

        return "Success. " + response_str, "GPT4-o (Highly Recommended)"
    except Exception as e:
        return "Invalid GPT4o API Key", "GPT4-o (Highly Recommended)"


def verify_deepseek_api():
    try:
        response = llm_model.chat.completions.create(
            model="deepseek-chat",
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Hello."}
            ]
        )
        response_str = response.choices[0].message.content

        return True, "Success. " + response_str

    except Exception as e:
        return False, "Invalid DeepSeek API Key"

def llm_enhanced_prompt_after_apply_instruction(image_caption, editing_prompt):
    try:
        messages = create_apply_editing_messages_deepseek(image_caption, editing_prompt)
        response = llm_model.chat.completions.create(
            model="deepseek-chat",
            messages=messages
        )
        response_str = response.choices[0].message.content
        return response_str
    except Exception as e:
        raise gr.Error(f"Error while integrating the instruction: {str(e)}. Please check the console log for details.")


def llm_decomposed_prompt_after_apply_instruction(integrated_query):
    try:
        messages = create_decomposed_query_messages_deepseek(integrated_query)
        response = llm_model.chat.completions.create(
            model="deepseek-chat",
            messages=messages
        )
        response_str = response.choices[0].message.content
        return response_str
    except Exception as e:
        raise gr.Error(f"Error while decomposing the instruction: {str(e)}. Please check the console log for details.")

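# ---------------------------------------------------------------------------
# Illustrative sketch (assumption, not the project's code): the two message
# builders used above come from app.deepseek.instructions and are not shown in
# this file. For an OpenAI-compatible chat endpoint they only need to return a
# list of {"role": ..., "content": ...} dicts, roughly like the hypothetical
# example below; the actual system prompt used by AnchorIT is not reproduced.
# ---------------------------------------------------------------------------
def _example_apply_editing_messages(image_caption, editing_prompt):
    """Minimal sketch of a chat-message builder for the integration step."""
    system_prompt = (
        "You merge an image caption and an editing instruction into one "
        "self-contained description of the edited image."
    )
    user_prompt = f"Caption: {image_caption}\nInstruction: {editing_prompt}"
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
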
def enhance_description(blip_description, prompt):
    try:
        if not prompt or not blip_description:
            print("Empty prompt or blip_description detected")
            return "", ""

        print(f"Enhancing with prompt: {prompt}")
        enhanced_description = llm_enhanced_prompt_after_apply_instruction(blip_description, prompt)
        return enhanced_description, enhanced_description

    except Exception as e:
        print(f"Enhancement failed: {str(e)}")
        return "Error occurred", "Error occurred"


def decompose_description(enhanced_description):
    try:
        if not enhanced_description:
            print("Empty enhanced_description detected")
            return "", ""

        print(f"Decomposing the enhanced description: {enhanced_description}")
        decomposed_description = llm_decomposed_prompt_after_apply_instruction(enhanced_description)
        return decomposed_description, decomposed_description

    except Exception as e:
        print(f"Decomposition failed: {str(e)}")
        return "Error occurred", "Error occurred"

@torch.no_grad()
def mix_and_search(enhanced_text: str, gallery_images: list):
    # Take the most recently generated image tuple from the gallery
    latest_item = gallery_images[-1] if gallery_images else None

    # Initialize the feature list
    features = []

    # Image feature extraction
    if latest_item and isinstance(latest_item, tuple):
        try:
            image_path = latest_item[0]
            pil_image = Image.open(image_path).convert("RGB")

            # Encode the image with CLIPProcessor
            image_inputs = clip_processor(
                images=pil_image,
                return_tensors="pt"
            ).to(device)

            image_features = clip_model.get_image_features(**image_inputs)
            features.append(F.normalize(image_features, dim=-1))
        except Exception as e:
            print(f"Image processing failed: {str(e)}")

    # Text feature extraction
    if enhanced_text.strip():
        text_inputs = clip_processor(
            text=enhanced_text,
            return_tensors="pt",
            padding=True,
            truncation=True
        ).to(device)

        text_features = clip_model.get_text_features(**text_inputs)
        features.append(F.normalize(text_features, dim=-1))

    if not features:
        return []

    # Feature fusion and retrieval
    mixed = sum(features) / len(features)
    mixed = F.normalize(mixed, dim=-1)

    # Load the Faiss index and the image-path mapping
    index_path = "/home/zt/data/open-images/train/knn.index"
    input_data_dir = Path("/home/zt/data/open-images/train/embedding_folder/metadata")
    base_image_dir = Path("/home/zt/data/open-images/train/")

    # Sort parquet files by the number in the file name and read them directly
    parquet_files = sorted(
        input_data_dir.glob('*.parquet'),
        key=lambda x: int(x.stem.split("_")[-1])
    )

    # Concatenate all parquet metadata
    dfs = [pd.read_parquet(file) for file in parquet_files]
    df = pd.concat(dfs, ignore_index=True)
    image_paths = df["image_path"].tolist()

    # Read the Faiss index
    index = faiss.read_index(index_path)
    assert mixed.shape[1] == index.d, "Feature dimension does not match the index"

    # Run the search
    mixed = mixed.cpu().detach().numpy().astype('float32')
    distances, indices = index.search(mixed, 50)

    # Collect and validate the retrieved image paths
    retrieved_images = []
    for idx in indices[0]:
        if 0 <= idx < len(image_paths):
            img_path = base_image_dir / image_paths[idx]
            try:
                if img_path.exists():
                    retrieved_images.append(Image.open(img_path).convert("RGB"))
                else:
                    print(f"Warning: missing file {img_path}")
            except Exception as e:
                print(f"Failed to load image: {str(e)}")

    return retrieved_images if retrieved_images else []

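# ---------------------------------------------------------------------------
# Illustrative sketch (assumption, not part of the original app): mix_and_search
# expects a prebuilt Faiss index (knn.index) plus parquet metadata containing an
# "image_path" column. Something along these lines, run offline over CLIP image
# embeddings of the retrieval corpus, would produce compatible files; the helper
# name and the argument names are hypothetical.
# ---------------------------------------------------------------------------
def _example_build_faiss_index(image_embeddings, image_paths, index_path, metadata_path):
    """Minimal sketch: store L2-normalized CLIP embeddings in an inner-product index."""
    embeddings = np.asarray(image_embeddings, dtype="float32")
    embeddings /= np.linalg.norm(embeddings, axis=1, keepdims=True)  # cosine similarity via inner product
    index = faiss.IndexFlatIP(embeddings.shape[1])
    index.add(embeddings)
    faiss.write_index(index, index_path)
    pd.DataFrame({"image_path": image_paths}).to_parquet(metadata_path)
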
block = gr.Blocks(
|
1604 |
+
theme=gr.themes.Soft(
|
1605 |
+
radius_size=gr.themes.sizes.radius_none,
|
1606 |
+
text_size=gr.themes.sizes.text_md
|
1607 |
+
)
|
1608 |
+
)
|
1609 |
+
|
1610 |
+
with block as demo:
|
1611 |
+
with gr.Row():
|
1612 |
+
with gr.Column():
|
1613 |
+
gr.HTML(head)
|
1614 |
+
gr.Markdown(descriptions)
|
1615 |
+
with gr.Accordion(label="🧭 Instructions:", open=True, elem_id="accordion"):
|
1616 |
+
with gr.Row(equal_height=True):
|
1617 |
+
gr.Markdown(instructions)
|
1618 |
+
|
1619 |
+
original_image = gr.State(value=None)
|
1620 |
+
original_mask = gr.State(value=None)
|
1621 |
+
category = gr.State(value=None)
|
1622 |
+
status = gr.State(value=None)
|
1623 |
+
invert_mask_state = gr.State(value=False)
|
1624 |
+
example_change_times = gr.State(value=0)
|
1625 |
+
deepseek_verified = gr.State(value=False)
|
1626 |
+
blip_description = gr.State(value="")
|
1627 |
+
enhanced_description = gr.State(value="")
|
1628 |
+
decomposed_description = gr.State(value="")
|
1629 |
+
|
1630 |
+
with gr.Row():
|
1631 |
+
with gr.Column():
|
1632 |
+
with gr.Group():
|
1633 |
+
input_image = gr.ImageEditor(
|
1634 |
+
label="参考图像",
|
1635 |
+
type="pil",
|
1636 |
+
brush=gr.Brush(colors=["#FFFFFF"], default_size = 30, color_mode="fixed"),
|
1637 |
+
layers = False,
|
1638 |
+
interactive=True,
|
1639 |
+
# height=1024,
|
1640 |
+
height=420,
|
1641 |
+
sources=["upload"],
|
1642 |
+
placeholder="🫧 点击此处或下面的图标上传图像 🫧",
|
1643 |
+
)
|
1644 |
+
prompt = gr.Textbox(label="修改指令", placeholder="😜 在此处输入你对参考图像的修改预期...", value="",lines=1)
|
1645 |
+
|
1646 |
+
with gr.Group():
|
1647 |
+
mask_button = gr.Button("💎 掩膜生成")
|
1648 |
+
with gr.Row():
|
1649 |
+
invert_mask_button = gr.Button("👐 掩膜翻转")
|
1650 |
+
random_mask_button = gr.Button("⭕️ 随机掩膜")
|
1651 |
+
with gr.Row():
|
1652 |
+
masked_gallery = gr.Gallery(label="掩膜图像", show_label=True, preview=True, height=360)
|
1653 |
+
mask_gallery = gr.Gallery(label="掩膜", show_label=True, preview=True, height=360)
|
1654 |
+
|
1655 |
+
|
1656 |
+
|
1657 |
+
with gr.Accordion("高级掩膜选项", open=False, elem_id="accordion1"):
|
1658 |
+
dilation_size = gr.Slider(
|
1659 |
+
label="每次放缩的尺度: ", show_label=True,minimum=0, maximum=50, step=1, value=20
|
1660 |
+
)
|
1661 |
+
with gr.Row():
|
1662 |
+
dilation_mask_button = gr.Button("放大掩膜")
|
1663 |
+
erosion_mask_button = gr.Button("缩小掩膜")
|
1664 |
+
|
1665 |
+
moving_pixels = gr.Slider(
|
1666 |
+
label="每次移动的像素:", show_label=True, minimum=0, maximum=50, value=4, step=1
|
1667 |
+
)
|
1668 |
+
with gr.Row():
|
1669 |
+
move_left_button = gr.Button("左移")
|
1670 |
+
move_right_button = gr.Button("右移")
|
1671 |
+
with gr.Row():
|
1672 |
+
move_up_button = gr.Button("上移")
|
1673 |
+
move_down_button = gr.Button("下移")
|
1674 |
+
|
1675 |
+
|
1676 |
+
|
1677 |
+
with gr.Column():
|
1678 |
+
with gr.Row():
|
1679 |
+
deepseek_key = gr.Textbox(label="LLM API密钥", value="sk-d145b963a92649a88843caeb741e8bbc", lines=2, container=False)
|
1680 |
+
verify_deepseek = gr.Button("🔑 验证密钥", scale=0)
|
1681 |
+
blip_output = gr.Textbox(label="1. 原图描述(BLIP生成)", placeholder="🖼️ 上传图片后自动生成图片描述...", lines=2, interactive=True)
|
1682 |
+
with gr.Row():
|
1683 |
+
enhanced_output = gr.Textbox(label="2. 整合增强版", lines=4, interactive=True, placeholder="🚀 点击右侧按钮生成增强描述...")
|
1684 |
+
enhance_button = gr.Button("✨ 智能整合")
|
1685 |
+
|
1686 |
+
with gr.Row():
|
1687 |
+
decomposed_output = gr.Textbox(label="3. 结构分解版", lines=4, interactive=True, placeholder="📝 点击右侧按钮生成结构化描述...")
|
1688 |
+
decompose_button = gr.Button("🔧 结构分解")
|
1689 |
+
|
1690 |
+
|
1691 |
+
|
1692 |
+
with gr.Group():
|
1693 |
+
run_button = gr.Button("💫 图像编辑")
|
1694 |
+
result_gallery = gr.Gallery(label="💥 编辑结果", show_label=True, columns=2, preview=True, height=360)
|
1695 |
+
|
1696 |
+
with gr.Accordion("高级编辑选项", open=False, elem_id="accordion1"):
|
1697 |
+
vlm_model_dropdown = gr.Dropdown(label="VLM 模型", choices=VLM_MODEL_NAMES, value=DEFAULT_VLM_MODEL_NAME, interactive=True)
|
1698 |
+
|
1699 |
+
with gr.Group():
|
1700 |
+
with gr.Row():
|
1701 |
+
# GPT4o_KEY = gr.Textbox(label="GPT4o API Key", placeholder="Please input your GPT4o API Key when use GPT4o VLM (highly recommended).", value="", lines=1)
|
1702 |
+
GPT4o_KEY = gr.Textbox(label="VLM API密钥", value="sk-d145b963a92649a88843caeb741e8bbc", lines=1)
|
1703 |
+
GPT4o_KEY_submit = gr.Button("🔑 验证密钥")
|
1704 |
+
|
1705 |
+
aspect_ratio = gr.Dropdown(label="输出纵横比", choices=ASPECT_RATIO_LABELS, value=DEFAULT_ASPECT_RATIO)
|
1706 |
+
resize_default = gr.Checkbox(label="短边裁剪到640像素", value=True)
|
1707 |
+
base_model_dropdown = gr.Dropdown(label="基础模型", choices=BASE_MODELS, value=DEFAULT_BASE_MODEL, interactive=True)
|
1708 |
+
negative_prompt = gr.Text(label="负向提示", max_lines=5, placeholder="请输入你的负向提示", value='ugly, low quality',lines=1)
|
1709 |
+
control_strength = gr.Slider(label="控制强度: ", show_label=True, minimum=0, maximum=1.1, value=1, step=0.01)
|
1710 |
+
with gr.Group():
|
1711 |
+
seed = gr.Slider(label="种子: ", minimum=0, maximum=2147483647, step=1, value=648464818)
|
1712 |
+
randomize_seed = gr.Checkbox(label="随机种子", value=False)
|
1713 |
+
blending = gr.Checkbox(label="混合模式", value=True)
|
1714 |
+
num_samples = gr.Slider(label="生成个数", minimum=0, maximum=4, step=1, value=2)
|
1715 |
+
with gr.Group():
|
1716 |
+
with gr.Row():
|
1717 |
+
guidance_scale = gr.Slider(label="指导尺度", minimum=1, maximum=12, step=0.1, value=7.5)
|
1718 |
+
num_inference_steps = gr.Slider(label="推理步数", minimum=1, maximum=50, step=1, value=50)
|
1719 |
+
target_prompt = gr.Text(label="Input Target Prompt", max_lines=5, placeholder="VLM-generated target prompt, you can first generate if and then modify it (optional)", value='', lines=2)
|
1720 |
+
|
1721 |
+
|
1722 |
+
|
1723 |
+
init_type = gr.Textbox(label="Init Name", value="", visible=False)
|
1724 |
+
example_type = gr.Textbox(label="Example Name", value="", visible=False)
|
1725 |
+
|
1726 |
+
with gr.Row():
|
1727 |
+
reset_button = gr.Button("Reset")
|
1728 |
+
retrieve_button = gr.Button("🔍 开始检索")
|
1729 |
+
|
1730 |
+
with gr.Row():
|
1731 |
+
retrieve_gallery = gr.Gallery(label="🎊 检索结果", show_label=True, columns=10, preview=True, height=800)
|
1732 |
+
|
1733 |
+
|
1734 |
+
with gr.Row():
|
1735 |
+
example = gr.Examples(
|
1736 |
+
label="Quick Example",
|
1737 |
+
examples=EXAMPLES,
|
1738 |
+
inputs=[input_image, prompt, seed, init_type, example_type, blending, resize_default, vlm_model_dropdown],
|
1739 |
+
examples_per_page=10,
|
1740 |
+
cache_examples=False,
|
1741 |
+
)
|
1742 |
+
|
1743 |
+
|
1744 |
+
with gr.Accordion(label="🎬 Feature Details:", open=True, elem_id="accordion"):
|
1745 |
+
with gr.Row(equal_height=True):
|
1746 |
+
gr.Markdown(tips)
|
1747 |
+
|
1748 |
+
with gr.Row():
|
1749 |
+
gr.Markdown(citation)
|
1750 |
+
|
1751 |
+
## gr.examples can not be used to update the gr.Gallery, so we need to use the following two functions to update the gr.Gallery.
|
1752 |
+
## And we need to solve the conflict between the upload and change example functions.
|
1753 |
+
input_image.upload(
|
1754 |
+
init_img,
|
1755 |
+
[input_image, init_type, prompt, aspect_ratio, example_change_times],
|
1756 |
+
[input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, init_type, aspect_ratio, resize_default, invert_mask_state, example_change_times]
|
1757 |
+
|
1758 |
+
)
|
1759 |
+
example_type.change(fn=update_example, inputs=[example_type, prompt, example_change_times], outputs=[input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, target_prompt, invert_mask_state, example_change_times])
|
1760 |
+
|
1761 |
+
|
1762 |
+
## vlm and base model dropdown
|
1763 |
+
vlm_model_dropdown.change(fn=update_vlm_model, inputs=[vlm_model_dropdown], outputs=[status])
|
1764 |
+
base_model_dropdown.change(fn=update_base_model, inputs=[base_model_dropdown], outputs=[status])
|
1765 |
+
|
1766 |
+
GPT4o_KEY_submit.click(fn=submit_GPT4o_KEY, inputs=[GPT4o_KEY], outputs=[GPT4o_KEY, vlm_model_dropdown])
|
1767 |
+
|
1768 |
+
|
1769 |
+
ips=[input_image,
|
1770 |
+
original_image,
|
1771 |
+
original_mask,
|
1772 |
+
prompt,
|
1773 |
+
negative_prompt,
|
1774 |
+
control_strength,
|
1775 |
+
seed,
|
1776 |
+
randomize_seed,
|
1777 |
+
guidance_scale,
|
1778 |
+
num_inference_steps,
|
1779 |
+
num_samples,
|
1780 |
+
blending,
|
1781 |
+
category,
|
1782 |
+
target_prompt,
|
1783 |
+
resize_default,
|
1784 |
+
aspect_ratio,
|
1785 |
+
invert_mask_state]
|
1786 |
+
|
1787 |
+
## run brushedit
|
1788 |
+
run_button.click(fn=process, inputs=ips, outputs=[result_gallery, mask_gallery, masked_gallery, prompt, target_prompt, invert_mask_state])
|
1789 |
+
|
1790 |
+
|
1791 |
+
## mask func
|
1792 |
+
mask_button.click(fn=process_mask, inputs=[input_image, original_image, prompt, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask, category])
|
1793 |
+
random_mask_button.click(fn=process_random_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
|
1794 |
+
dilation_mask_button.click(fn=process_dilation_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[ masked_gallery, mask_gallery, original_mask])
|
1795 |
+
erosion_mask_button.click(fn=process_erosion_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[ masked_gallery, mask_gallery, original_mask])
|
1796 |
+
invert_mask_button.click(fn=invert_mask, inputs=[input_image, original_image, original_mask], outputs=[masked_gallery, mask_gallery, original_mask, invert_mask_state])
|
1797 |
+
|
1798 |
+
## reset func
|
1799 |
+
reset_button.click(fn=reset_func, inputs=[input_image, original_image, original_mask, prompt, target_prompt], outputs=[input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, resize_default, invert_mask_state])
|
1800 |
+
|
1801 |
+
input_image.upload(fn=generate_blip_description, inputs=[input_image], outputs=[blip_description, blip_output])
|
1802 |
+
verify_deepseek.click(fn=verify_deepseek_api, outputs=[deepseek_verified, deepseek_key])
|
1803 |
+
enhance_button.click(fn=enhance_description, inputs=[blip_output, prompt], outputs=[enhanced_description, enhanced_output])
|
1804 |
+
decompose_button.click(fn=decompose_description, inputs=[enhanced_output], outputs=[decomposed_description, decomposed_output])
|
1805 |
+
retrieve_button.click(fn=mix_and_search, inputs=[enhanced_output, result_gallery], outputs=[retrieve_gallery])
|
1806 |
+
|
1807 |
+
demo.launch(server_name="0.0.0.0", server_port=12345, share=True)
|
1808 |
+
|
1809 |
+
|
brushedit_app_new_doable.py
ADDED
@@ -0,0 +1,1860 @@
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import os, random, sys
|
4 |
+
import numpy as np
|
5 |
+
import requests
|
6 |
+
import torch
|
7 |
+
|
8 |
+
from pathlib import Path
|
9 |
+
import pandas as pd
|
10 |
+
import concurrent.futures
|
11 |
+
import faiss
|
12 |
+
import gradio as gr
|
13 |
+
|
14 |
+
from PIL import Image
|
15 |
+
|
16 |
+
import torch.nn.functional as F  # newly added import
|
17 |
+
from huggingface_hub import hf_hub_download, snapshot_download
|
18 |
+
from scipy.ndimage import binary_dilation, binary_erosion
|
19 |
+
from transformers import (LlavaNextProcessor, LlavaNextForConditionalGeneration,
|
20 |
+
Qwen2VLForConditionalGeneration, Qwen2VLProcessor)
|
21 |
+
|
22 |
+
from segment_anything import SamPredictor, build_sam, SamAutomaticMaskGenerator
|
23 |
+
from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler
|
24 |
+
from diffusers.image_processor import VaeImageProcessor
|
25 |
+
|
26 |
+
|
27 |
+
from app.src.vlm_pipeline import (
|
28 |
+
vlm_response_editing_type,
|
29 |
+
vlm_response_object_wait_for_edit,
|
30 |
+
vlm_response_mask,
|
31 |
+
vlm_response_prompt_after_apply_instruction
|
32 |
+
)
|
33 |
+
from app.src.brushedit_all_in_one_pipeline import BrushEdit_Pipeline
|
34 |
+
from app.utils.utils import load_grounding_dino_model
|
35 |
+
|
36 |
+
from app.src.vlm_template import vlms_template
|
37 |
+
from app.src.base_model_template import base_models_template
|
38 |
+
from app.src.aspect_ratio_template import aspect_ratios
|
39 |
+
|
40 |
+
from openai import OpenAI
|
41 |
+
base_openai_url = "https://api.deepseek.com/"
|
42 |
+
base_api_key = "sk-d145b963a92649a88843caeb741e8bbc"
|
43 |
+
|
44 |
+
|
45 |
+
from transformers import BlipProcessor, BlipForConditionalGeneration
|
46 |
+
from transformers import CLIPProcessor, CLIPModel
|
47 |
+
|
48 |
+
from app.deepseek.instructions import (
|
49 |
+
create_apply_editing_messages_deepseek,
|
50 |
+
create_decomposed_query_messages_deepseek
|
51 |
+
)
|
52 |
+
from clip_retrieval.clip_client import ClipClient
|
53 |
+
|
54 |
+
#### Description ####
|
55 |
+
logo = r"""
|
56 |
+
<center><img src='./assets/logo_brushedit.png' alt='BrushEdit logo' style="width:80px; margin-bottom:10px"></center>
|
57 |
+
"""
|
58 |
+
head = r"""
|
59 |
+
<div style="text-align: center;">
|
60 |
+
<h1> 基于扩散模型先验和大语言模型的零样本组合查询图像检索</h1>
|
61 |
+
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
|
62 |
+
<a href=''><img src='https://img.shields.io/badge/Project_Page-BrushEdit-green' alt='Project Page'></a>
|
63 |
+
<a href=''><img src='https://img.shields.io/badge/Paper-Arxiv-blue'></a>
|
64 |
+
<a href=''><img src='https://img.shields.io/badge/Code-Github-orange'></a>
|
65 |
+
|
66 |
+
</div>
|
67 |
+
</br>
|
68 |
+
</div>
|
69 |
+
"""
|
70 |
+
descriptions = r"""
|
71 |
+
Demo for ZS-CIR"""
|
72 |
+
|
73 |
+
instructions = r"""
|
74 |
+
Demo for ZS-CIR"""
|
75 |
+
|
76 |
+
tips = r"""
|
77 |
+
Demo for ZS-CIR
|
78 |
+
|
79 |
+
"""
|
80 |
+
|
81 |
+
|
82 |
+
|
83 |
+
citation = r"""
|
84 |
+
Demo for ZS-CIR"""
|
85 |
+
|
86 |
+
# - - - - - examples - - - - - #
|
87 |
+
EXAMPLES = [
|
88 |
+
|
89 |
+
[
|
90 |
+
Image.open("./assets/frog/frog.jpeg").convert("RGBA"),
|
91 |
+
"add a magic hat on frog head.",
|
92 |
+
642087011,
|
93 |
+
"frog",
|
94 |
+
"frog",
|
95 |
+
True,
|
96 |
+
False,
|
97 |
+
"GPT4-o (Highly Recommended)"
|
98 |
+
],
|
99 |
+
[
|
100 |
+
Image.open("./assets/chinese_girl/chinese_girl.png").convert("RGBA"),
|
101 |
+
"replace the background to ancient China.",
|
102 |
+
648464818,
|
103 |
+
"chinese_girl",
|
104 |
+
"chinese_girl",
|
105 |
+
True,
|
106 |
+
False,
|
107 |
+
"GPT4-o (Highly Recommended)"
|
108 |
+
],
|
109 |
+
[
|
110 |
+
Image.open("./assets/angel_christmas/angel_christmas.png").convert("RGBA"),
|
111 |
+
"remove the deer.",
|
112 |
+
648464818,
|
113 |
+
"angel_christmas",
|
114 |
+
"angel_christmas",
|
115 |
+
False,
|
116 |
+
False,
|
117 |
+
"GPT4-o (Highly Recommended)"
|
118 |
+
],
|
119 |
+
[
|
120 |
+
Image.open("./assets/sunflower_girl/sunflower_girl.png").convert("RGBA"),
|
121 |
+
"add a wreath on head.",
|
122 |
+
648464818,
|
123 |
+
"sunflower_girl",
|
124 |
+
"sunflower_girl",
|
125 |
+
True,
|
126 |
+
False,
|
127 |
+
"GPT4-o (Highly Recommended)"
|
128 |
+
],
|
129 |
+
[
|
130 |
+
Image.open("./assets/girl_on_sun/girl_on_sun.png").convert("RGBA"),
|
131 |
+
"add a butterfly fairy.",
|
132 |
+
648464818,
|
133 |
+
"girl_on_sun",
|
134 |
+
"girl_on_sun",
|
135 |
+
True,
|
136 |
+
False,
|
137 |
+
"GPT4-o (Highly Recommended)"
|
138 |
+
],
|
139 |
+
[
|
140 |
+
Image.open("./assets/spider_man_rm/spider_man.png").convert("RGBA"),
|
141 |
+
"remove the christmas hat.",
|
142 |
+
642087011,
|
143 |
+
"spider_man_rm",
|
144 |
+
"spider_man_rm",
|
145 |
+
False,
|
146 |
+
False,
|
147 |
+
"GPT4-o (Highly Recommended)"
|
148 |
+
],
|
149 |
+
[
|
150 |
+
Image.open("./assets/anime_flower/anime_flower.png").convert("RGBA"),
|
151 |
+
"remove the flower.",
|
152 |
+
642087011,
|
153 |
+
"anime_flower",
|
154 |
+
"anime_flower",
|
155 |
+
False,
|
156 |
+
False,
|
157 |
+
"GPT4-o (Highly Recommended)"
|
158 |
+
],
|
159 |
+
[
|
160 |
+
Image.open("./assets/chenduling/chengduling.jpg").convert("RGBA"),
|
161 |
+
"replace the clothes to a delicated floral skirt.",
|
162 |
+
648464818,
|
163 |
+
"chenduling",
|
164 |
+
"chenduling",
|
165 |
+
True,
|
166 |
+
False,
|
167 |
+
"GPT4-o (Highly Recommended)"
|
168 |
+
],
|
169 |
+
[
|
170 |
+
Image.open("./assets/hedgehog_rp_bg/hedgehog.png").convert("RGBA"),
|
171 |
+
"make the hedgehog in Italy.",
|
172 |
+
648464818,
|
173 |
+
"hedgehog_rp_bg",
|
174 |
+
"hedgehog_rp_bg",
|
175 |
+
True,
|
176 |
+
False,
|
177 |
+
"GPT4-o (Highly Recommended)"
|
178 |
+
],
|
179 |
+
|
180 |
+
]
|
181 |
+
|
182 |
+
INPUT_IMAGE_PATH = {
|
183 |
+
"frog": "./assets/frog/frog.jpeg",
|
184 |
+
"chinese_girl": "./assets/chinese_girl/chinese_girl.png",
|
185 |
+
"angel_christmas": "./assets/angel_christmas/angel_christmas.png",
|
186 |
+
"sunflower_girl": "./assets/sunflower_girl/sunflower_girl.png",
|
187 |
+
"girl_on_sun": "./assets/girl_on_sun/girl_on_sun.png",
|
188 |
+
"spider_man_rm": "./assets/spider_man_rm/spider_man.png",
|
189 |
+
"anime_flower": "./assets/anime_flower/anime_flower.png",
|
190 |
+
"chenduling": "./assets/chenduling/chengduling.jpg",
|
191 |
+
"hedgehog_rp_bg": "./assets/hedgehog_rp_bg/hedgehog.png",
|
192 |
+
}
|
193 |
+
MASK_IMAGE_PATH = {
|
194 |
+
"frog": "./assets/frog/mask_f7b350de-6f2c-49e3-b535-995c486d78e7.png",
|
195 |
+
"chinese_girl": "./assets/chinese_girl/mask_54759648-0989-48e0-bc82-f20e28b5ec29.png",
|
196 |
+
"angel_christmas": "./assets/angel_christmas/mask_f15d9b45-c978-4e3d-9f5f-251e308560c3.png",
|
197 |
+
"sunflower_girl": "./assets/sunflower_girl/mask_99cc50b4-7dc4-4de5-8748-ec10772f0317.png",
|
198 |
+
"girl_on_sun": "./assets/girl_on_sun/mask_264eac8b-8b65-479c-9755-020a60880c37.png",
|
199 |
+
"spider_man_rm": "./assets/spider_man_rm/mask_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png",
|
200 |
+
"anime_flower": "./assets/anime_flower/mask_37553172-9b38-4727-bf2e-37d7e2b93461.png",
|
201 |
+
"chenduling": "./assets/chenduling/mask_68e3ff6f-da07-4b37-91df-13d6eed7b997.png",
|
202 |
+
"hedgehog_rp_bg": "./assets/hedgehog_rp_bg/mask_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png",
|
203 |
+
}
|
204 |
+
MASKED_IMAGE_PATH = {
|
205 |
+
"frog": "./assets/frog/masked_image_f7b350de-6f2c-49e3-b535-995c486d78e7.png",
|
206 |
+
"chinese_girl": "./assets/chinese_girl/masked_image_54759648-0989-48e0-bc82-f20e28b5ec29.png",
|
207 |
+
"angel_christmas": "./assets/angel_christmas/masked_image_f15d9b45-c978-4e3d-9f5f-251e308560c3.png",
|
208 |
+
"sunflower_girl": "./assets/sunflower_girl/masked_image_99cc50b4-7dc4-4de5-8748-ec10772f0317.png",
|
209 |
+
"girl_on_sun": "./assets/girl_on_sun/masked_image_264eac8b-8b65-479c-9755-020a60880c37.png",
|
210 |
+
"spider_man_rm": "./assets/spider_man_rm/masked_image_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png",
|
211 |
+
"anime_flower": "./assets/anime_flower/masked_image_37553172-9b38-4727-bf2e-37d7e2b93461.png",
|
212 |
+
"chenduling": "./assets/chenduling/masked_image_68e3ff6f-da07-4b37-91df-13d6eed7b997.png",
|
213 |
+
"hedgehog_rp_bg": "./assets/hedgehog_rp_bg/masked_image_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png",
|
214 |
+
}
|
215 |
+
OUTPUT_IMAGE_PATH = {
|
216 |
+
"frog": "./assets/frog/image_edit_f7b350de-6f2c-49e3-b535-995c486d78e7_1.png",
|
217 |
+
"chinese_girl": "./assets/chinese_girl/image_edit_54759648-0989-48e0-bc82-f20e28b5ec29_1.png",
|
218 |
+
"angel_christmas": "./assets/angel_christmas/image_edit_f15d9b45-c978-4e3d-9f5f-251e308560c3_0.png",
|
219 |
+
"sunflower_girl": "./assets/sunflower_girl/image_edit_99cc50b4-7dc4-4de5-8748-ec10772f0317_3.png",
|
220 |
+
"girl_on_sun": "./assets/girl_on_sun/image_edit_264eac8b-8b65-479c-9755-020a60880c37_0.png",
|
221 |
+
"spider_man_rm": "./assets/spider_man_rm/image_edit_a5d410e6-8e8d-432f-8144-defbc3e1eae9_0.png",
|
222 |
+
"anime_flower": "./assets/anime_flower/image_edit_37553172-9b38-4727-bf2e-37d7e2b93461_2.png",
|
223 |
+
"chenduling": "./assets/chenduling/image_edit_68e3ff6f-da07-4b37-91df-13d6eed7b997_0.png",
|
224 |
+
"hedgehog_rp_bg": "./assets/hedgehog_rp_bg/image_edit_db7f8bf8-8349-46d3-b14e-43d67fbe25d3_3.png",
|
225 |
+
}
|
226 |
+
|
227 |
+
# os.environ['GRADIO_TEMP_DIR'] = 'gradio_temp_dir'
|
228 |
+
# os.makedirs('gradio_temp_dir', exist_ok=True)
|
229 |
+
|
230 |
+
VLM_MODEL_NAMES = list(vlms_template.keys())
|
231 |
+
DEFAULT_VLM_MODEL_NAME = "Qwen2-VL-7B-Instruct (Default)"
|
232 |
+
|
233 |
+
|
234 |
+
BASE_MODELS = list(base_models_template.keys())
|
235 |
+
DEFAULT_BASE_MODEL = "realisticVision (Default)"
|
236 |
+
|
237 |
+
ASPECT_RATIO_LABELS = list(aspect_ratios)
|
238 |
+
DEFAULT_ASPECT_RATIO = ASPECT_RATIO_LABELS[0]
|
239 |
+
|
240 |
+
|
241 |
+
## init device
|
242 |
+
try:
|
243 |
+
if torch.cuda.is_available():
|
244 |
+
device = "cuda"
|
245 |
+
elif sys.platform == "darwin" and torch.backends.mps.is_available():
|
246 |
+
device = "mps"
|
247 |
+
else:
|
248 |
+
device = "cpu"
|
249 |
+
except:
|
250 |
+
device = "cpu"
|
251 |
+
|
252 |
+
# ## init torch dtype
|
253 |
+
# if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
|
254 |
+
# torch_dtype = torch.bfloat16
|
255 |
+
# else:
|
256 |
+
# torch_dtype = torch.float16
|
257 |
+
|
258 |
+
# if device == "mps":
|
259 |
+
# torch_dtype = torch.float16
|
260 |
+
|
261 |
+
torch_dtype = torch.float16
|
262 |
+
|
263 |
+
|
264 |
+
|
265 |
+
# download hf models
|
266 |
+
BrushEdit_path = "models/"
|
267 |
+
if not os.path.exists(BrushEdit_path):
|
268 |
+
BrushEdit_path = snapshot_download(
|
269 |
+
repo_id="TencentARC/BrushEdit",
|
270 |
+
local_dir=BrushEdit_path,
|
271 |
+
token=os.getenv("HF_TOKEN"),
|
272 |
+
)
|
273 |
+
|
274 |
+
## init default VLM
|
275 |
+
vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[DEFAULT_VLM_MODEL_NAME]
|
276 |
+
if vlm_processor != "" and vlm_model != "":
|
277 |
+
vlm_model.to(device)
|
278 |
+
else:
|
279 |
+
raise gr.Error("Please Download default VLM model "+ DEFAULT_VLM_MODEL_NAME +" first.")
|
280 |
+
|
281 |
+
## init default LLM
|
282 |
+
llm_model = OpenAI(api_key=base_api_key, base_url=base_openai_url)
|
283 |
+
|
284 |
+
## init base model
|
285 |
+
base_model_path = os.path.join(BrushEdit_path, "base_model/realisticVisionV60B1_v51VAE")
|
286 |
+
brushnet_path = os.path.join(BrushEdit_path, "brushnetX")
|
287 |
+
sam_path = os.path.join(BrushEdit_path, "sam/sam_vit_h_4b8939.pth")
|
288 |
+
groundingdino_path = os.path.join(BrushEdit_path, "grounding_dino/groundingdino_swint_ogc.pth")
|
289 |
+
|
290 |
+
|
291 |
+
# input brushnetX ckpt path
|
292 |
+
brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch_dtype)
|
293 |
+
pipe = StableDiffusionBrushNetPipeline.from_pretrained(
|
294 |
+
base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
|
295 |
+
)
|
296 |
+
# speed up diffusion process with faster scheduler and memory optimization
|
297 |
+
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
|
298 |
+
# remove following line if xformers is not installed or when using Torch 2.0.
|
299 |
+
# pipe.enable_xformers_memory_efficient_attention()
|
300 |
+
pipe.enable_model_cpu_offload()
|
301 |
+
|
302 |
+
|
303 |
+
## init SAM
|
304 |
+
sam = build_sam(checkpoint=sam_path)
|
305 |
+
sam.to(device=device)
|
306 |
+
sam_predictor = SamPredictor(sam)
|
307 |
+
sam_automask_generator = SamAutomaticMaskGenerator(sam)
|
308 |
+
|
309 |
+
## init groundingdino_model
|
310 |
+
config_file = 'app/utils/GroundingDINO_SwinT_OGC.py'
|
311 |
+
groundingdino_model = load_grounding_dino_model(config_file, groundingdino_path, device=device)
|
312 |
+
|
313 |
+
## Ordinary function
|
314 |
+
def crop_and_resize(image: Image.Image,
|
315 |
+
target_width: int,
|
316 |
+
target_height: int) -> Image.Image:
|
317 |
+
"""
|
318 |
+
Crops and resizes an image while preserving the aspect ratio.
|
319 |
+
|
320 |
+
Args:
|
321 |
+
image (Image.Image): Input PIL image to be cropped and resized.
|
322 |
+
target_width (int): Target width of the output image.
|
323 |
+
target_height (int): Target height of the output image.
|
324 |
+
|
325 |
+
Returns:
|
326 |
+
Image.Image: Cropped and resized image.
|
327 |
+
"""
|
328 |
+
# Original dimensions
|
329 |
+
original_width, original_height = image.size
|
330 |
+
original_aspect = original_width / original_height
|
331 |
+
target_aspect = target_width / target_height
|
332 |
+
|
333 |
+
# Calculate crop box to maintain aspect ratio
|
334 |
+
if original_aspect > target_aspect:
|
335 |
+
# Crop horizontally
|
336 |
+
new_width = int(original_height * target_aspect)
|
337 |
+
new_height = original_height
|
338 |
+
left = (original_width - new_width) / 2
|
339 |
+
top = 0
|
340 |
+
right = left + new_width
|
341 |
+
bottom = original_height
|
342 |
+
else:
|
343 |
+
# Crop vertically
|
344 |
+
new_width = original_width
|
345 |
+
new_height = int(original_width / target_aspect)
|
346 |
+
left = 0
|
347 |
+
top = (original_height - new_height) / 2
|
348 |
+
right = original_width
|
349 |
+
bottom = top + new_height
|
350 |
+
|
351 |
+
# Crop and resize
|
352 |
+
cropped_image = image.crop((left, top, right, bottom))
|
353 |
+
resized_image = cropped_image.resize((target_width, target_height), Image.NEAREST)
|
354 |
+
return resized_image
|
355 |
+
|
356 |
+
|
357 |
+
## Ordinary function
|
358 |
+
def resize(image: Image.Image,
|
359 |
+
target_width: int,
|
360 |
+
target_height: int) -> Image.Image:
|
361 |
+
"""
|
362 |
+
Crops and resizes an image while preserving the aspect ratio.
|
363 |
+
|
364 |
+
Args:
|
365 |
+
image (Image.Image): Input PIL image to be cropped and resized.
|
366 |
+
target_width (int): Target width of the output image.
|
367 |
+
target_height (int): Target height of the output image.
|
368 |
+
|
369 |
+
Returns:
|
370 |
+
Image.Image: Cropped and resized image.
|
371 |
+
"""
|
372 |
+
# Original dimensions
|
373 |
+
resized_image = image.resize((target_width, target_height), Image.NEAREST)
|
374 |
+
return resized_image
|
375 |
+
|
376 |
+
|
377 |
+
def move_mask_func(mask, direction, units):
|
378 |
+
binary_mask = mask.squeeze()>0
|
379 |
+
rows, cols = binary_mask.shape
|
380 |
+
moved_mask = np.zeros_like(binary_mask, dtype=bool)
|
381 |
+
|
382 |
+
if direction == 'down':
|
383 |
+
# move down
|
384 |
+
moved_mask[max(0, units):, :] = binary_mask[:rows - units, :]
|
385 |
+
|
386 |
+
elif direction == 'up':
|
387 |
+
# move up
|
388 |
+
moved_mask[:rows - units, :] = binary_mask[units:, :]
|
389 |
+
|
390 |
+
elif direction == 'right':
|
391 |
+
# move left
|
392 |
+
moved_mask[:, max(0, units):] = binary_mask[:, :cols - units]
|
393 |
+
|
394 |
+
elif direction == 'left':
|
395 |
+
# move right
|
396 |
+
moved_mask[:, :cols - units] = binary_mask[:, units:]
|
397 |
+
|
398 |
+
return moved_mask
|
399 |
+
|
400 |
+
|
401 |
+
def random_mask_func(mask, dilation_type='square', dilation_size=20):
|
402 |
+
# Randomly select the size of dilation
|
403 |
+
binary_mask = mask.squeeze()>0
|
404 |
+
|
405 |
+
if dilation_type == 'square_dilation':
|
406 |
+
structure = np.ones((dilation_size, dilation_size), dtype=bool)
|
407 |
+
dilated_mask = binary_dilation(binary_mask, structure=structure)
|
408 |
+
elif dilation_type == 'square_erosion':
|
409 |
+
structure = np.ones((dilation_size, dilation_size), dtype=bool)
|
410 |
+
dilated_mask = binary_erosion(binary_mask, structure=structure)
|
411 |
+
elif dilation_type == 'bounding_box':
|
412 |
+
# find the most left top and left bottom point
|
413 |
+
rows, cols = np.where(binary_mask)
|
414 |
+
if len(rows) == 0 or len(cols) == 0:
|
415 |
+
return mask # return original mask if no valid points
|
416 |
+
|
417 |
+
min_row = np.min(rows)
|
418 |
+
max_row = np.max(rows)
|
419 |
+
min_col = np.min(cols)
|
420 |
+
max_col = np.max(cols)
|
421 |
+
|
422 |
+
# create a bounding box
|
423 |
+
dilated_mask = np.zeros_like(binary_mask, dtype=bool)
|
424 |
+
dilated_mask[min_row:max_row + 1, min_col:max_col + 1] = True
|
425 |
+
|
426 |
+
elif dilation_type == 'bounding_ellipse':
|
427 |
+
# find the most left top and left bottom point
|
428 |
+
rows, cols = np.where(binary_mask)
|
429 |
+
if len(rows) == 0 or len(cols) == 0:
|
430 |
+
return mask # return original mask if no valid points
|
431 |
+
|
432 |
+
min_row = np.min(rows)
|
433 |
+
max_row = np.max(rows)
|
434 |
+
min_col = np.min(cols)
|
435 |
+
max_col = np.max(cols)
|
436 |
+
|
437 |
+
# calculate the center and axis length of the ellipse
|
438 |
+
center = ((min_col + max_col) // 2, (min_row + max_row) // 2)
|
439 |
+
a = (max_col - min_col) // 2 # half long axis
|
440 |
+
b = (max_row - min_row) // 2 # half short axis
|
441 |
+
|
442 |
+
# create a bounding ellipse
|
443 |
+
y, x = np.ogrid[:mask.shape[0], :mask.shape[1]]
|
444 |
+
ellipse_mask = ((x - center[0])**2 / a**2 + (y - center[1])**2 / b**2) <= 1
|
445 |
+
dilated_mask = np.zeros_like(binary_mask, dtype=bool)
|
446 |
+
dilated_mask[ellipse_mask] = True
|
447 |
+
else:
|
448 |
+
ValueError("dilation_type must be 'square' or 'ellipse'")
|
449 |
+
|
450 |
+
# use binary dilation
|
451 |
+
dilated_mask = np.uint8(dilated_mask[:,:,np.newaxis]) * 255
|
452 |
+
return dilated_mask
|
453 |
+
|
454 |
+
|
455 |
+
## Gradio component function
|
456 |
+
def update_vlm_model(vlm_name):
|
457 |
+
global vlm_model, vlm_processor
|
458 |
+
if vlm_model is not None:
|
459 |
+
del vlm_model
|
460 |
+
torch.cuda.empty_cache()
|
461 |
+
|
462 |
+
vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[vlm_name]
|
463 |
+
|
464 |
+
## we recommend using preload models, otherwise it will take a long time to download the model. you can edit the code via vlm_template.py
|
465 |
+
if vlm_type == "llava-next":
|
466 |
+
if vlm_processor != "" and vlm_model != "":
|
467 |
+
vlm_model.to(device)
|
468 |
+
return vlm_model_dropdown
|
469 |
+
else:
|
470 |
+
if os.path.exists(vlm_local_path):
|
471 |
+
vlm_processor = LlavaNextProcessor.from_pretrained(vlm_local_path)
|
472 |
+
vlm_model = LlavaNextForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto")
|
473 |
+
else:
|
474 |
+
if vlm_name == "llava-v1.6-mistral-7b-hf (Preload)":
|
475 |
+
vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
|
476 |
+
vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", torch_dtype="auto", device_map="auto")
|
477 |
+
elif vlm_name == "llama3-llava-next-8b-hf (Preload)":
|
478 |
+
vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llama3-llava-next-8b-hf")
|
479 |
+
vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llama3-llava-next-8b-hf", torch_dtype="auto", device_map="auto")
|
480 |
+
elif vlm_name == "llava-v1.6-vicuna-13b-hf (Preload)":
|
481 |
+
vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf")
|
482 |
+
vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf", torch_dtype="auto", device_map="auto")
|
483 |
+
elif vlm_name == "llava-v1.6-34b-hf (Preload)":
|
484 |
+
vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-34b-hf")
|
485 |
+
vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-34b-hf", torch_dtype="auto", device_map="auto")
|
486 |
+
elif vlm_name == "llava-next-72b-hf (Preload)":
|
487 |
+
vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-next-72b-hf")
|
488 |
+
vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-next-72b-hf", torch_dtype="auto", device_map="auto")
|
489 |
+
elif vlm_type == "qwen2-vl":
|
490 |
+
if vlm_processor != "" and vlm_model != "":
|
491 |
+
vlm_model.to(device)
|
492 |
+
return vlm_model_dropdown
|
493 |
+
else:
|
494 |
+
if os.path.exists(vlm_local_path):
|
495 |
+
vlm_processor = Qwen2VLProcessor.from_pretrained(vlm_local_path)
|
496 |
+
vlm_model = Qwen2VLForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto")
|
497 |
+
else:
|
498 |
+
if vlm_name == "qwen2-vl-2b-instruct (Preload)":
|
499 |
+
vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
|
500 |
+
vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", torch_dtype="auto", device_map="auto")
|
501 |
+
elif vlm_name == "qwen2-vl-7b-instruct (Preload)":
|
502 |
+
vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
|
503 |
+
vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto")
|
504 |
+
elif vlm_name == "qwen2-vl-72b-instruct (Preload)":
|
505 |
+
vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-72B-Instruct")
|
506 |
+
vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-72B-Instruct", torch_dtype="auto", device_map="auto")
|
507 |
+
elif vlm_type == "openai":
|
508 |
+
pass
|
509 |
+
return "success"
|
510 |
+
|
511 |
+
|
512 |
+
def update_base_model(base_model_name):
|
513 |
+
global pipe
|
514 |
+
## we recommend using preload models, otherwise it will take a long time to download the model. you can edit the code via base_model_template.py
|
515 |
+
if pipe is not None:
|
516 |
+
del pipe
|
517 |
+
torch.cuda.empty_cache()
|
518 |
+
base_model_path, pipe = base_models_template[base_model_name]
|
519 |
+
if pipe != "":
|
520 |
+
pipe.to(device)
|
521 |
+
else:
|
522 |
+
if os.path.exists(base_model_path):
|
523 |
+
pipe = StableDiffusionBrushNetPipeline.from_pretrained(
|
524 |
+
base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
|
525 |
+
)
|
526 |
+
# pipe.enable_xformers_memory_efficient_attention()
|
527 |
+
pipe.enable_model_cpu_offload()
|
528 |
+
else:
|
529 |
+
raise gr.Error(f"The base model {base_model_name} does not exist")
|
530 |
+
return "success"
|
531 |
+
|
532 |
+
|
533 |
+
def process_random_mask(input_image,
|
534 |
+
original_image,
|
535 |
+
original_mask,
|
536 |
+
resize_default,
|
537 |
+
aspect_ratio_name,
|
538 |
+
):
|
539 |
+
|
540 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
541 |
+
input_mask = np.asarray(alpha_mask)
|
542 |
+
|
543 |
+
output_w, output_h = aspect_ratios[aspect_ratio_name]
|
544 |
+
if output_w == "" or output_h == "":
|
545 |
+
output_h, output_w = original_image.shape[:2]
|
546 |
+
if resize_default:
|
547 |
+
short_side = min(output_w, output_h)
|
548 |
+
scale_ratio = 640 / short_side
|
549 |
+
output_w = int(output_w * scale_ratio)
|
550 |
+
output_h = int(output_h * scale_ratio)
|
551 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
552 |
+
original_image = np.array(original_image)
|
553 |
+
if input_mask is not None:
|
554 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
555 |
+
input_mask = np.array(input_mask)
|
556 |
+
if original_mask is not None:
|
557 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
558 |
+
original_mask = np.array(original_mask)
|
559 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
560 |
+
else:
|
561 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
562 |
+
pass
|
563 |
+
else:
|
564 |
+
if resize_default:
|
565 |
+
short_side = min(output_w, output_h)
|
566 |
+
scale_ratio = 640 / short_side
|
567 |
+
output_w = int(output_w * scale_ratio)
|
568 |
+
output_h = int(output_h * scale_ratio)
|
569 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
570 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
571 |
+
original_image = np.array(original_image)
|
572 |
+
if input_mask is not None:
|
573 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
574 |
+
input_mask = np.array(input_mask)
|
575 |
+
if original_mask is not None:
|
576 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
577 |
+
original_mask = np.array(original_mask)
|
578 |
+
|
579 |
+
|
580 |
+
if input_mask.max() == 0:
|
581 |
+
original_mask = original_mask
|
582 |
+
else:
|
583 |
+
original_mask = input_mask
|
584 |
+
|
585 |
+
if original_mask is None:
|
586 |
+
raise gr.Error('Please generate mask first')
|
587 |
+
|
588 |
+
if original_mask.ndim == 2:
|
589 |
+
original_mask = original_mask[:,:,None]
|
590 |
+
|
591 |
+
dilation_type = np.random.choice(['bounding_box', 'bounding_ellipse'])
|
592 |
+
random_mask = random_mask_func(original_mask, dilation_type).squeeze()
|
593 |
+
|
594 |
+
mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
|
595 |
+
|
596 |
+
masked_image = original_image * (1 - (random_mask[:,:,None]>0))
|
597 |
+
masked_image = masked_image.astype(original_image.dtype)
|
598 |
+
masked_image = Image.fromarray(masked_image)
|
599 |
+
|
600 |
+
|
601 |
+
return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
|
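random_mask_func is defined earlier in this file and is not part of this hunk; the sketch below shows one plausible way the 'bounding_box' / 'bounding_ellipse' branches could be realized (an illustrative assumption, not the repo's implementation).

import cv2
import numpy as np

def random_mask_sketch(mask, dilation_type="bounding_box"):
    # mask: (H, W, 1) uint8, nonzero where the object is.
    m = (mask.squeeze() > 0).astype(np.uint8)
    ys, xs = np.nonzero(m)
    if len(xs) == 0:
        return mask
    x0, x1, y0, y1 = int(xs.min()), int(xs.max()), int(ys.min()), int(ys.max())
    out = np.zeros_like(m)
    if dilation_type == "bounding_box":
        out[y0:y1 + 1, x0:x1 + 1] = 1  # fill the axis-aligned bounding box
    else:  # "bounding_ellipse"
        center = ((x0 + x1) // 2, (y0 + y1) // 2)
        axes = ((x1 - x0) // 2 + 1, (y1 - y0) // 2 + 1)
        cv2.ellipse(out, center, axes, 0, 0, 360, 1, -1)  # filled ellipse over the box
    return (out * 255)[:, :, None]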
602 |
+
|
603 |
+
|
604 |
+
def process_dilation_mask(input_image,
|
605 |
+
original_image,
|
606 |
+
original_mask,
|
607 |
+
resize_default,
|
608 |
+
aspect_ratio_name,
|
609 |
+
dilation_size=20):
|
610 |
+
|
611 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
612 |
+
input_mask = np.asarray(alpha_mask)
|
613 |
+
|
614 |
+
output_w, output_h = aspect_ratios[aspect_ratio_name]
|
615 |
+
if output_w == "" or output_h == "":
|
616 |
+
output_h, output_w = original_image.shape[:2]
|
617 |
+
if resize_default:
|
618 |
+
short_side = min(output_w, output_h)
|
619 |
+
scale_ratio = 640 / short_side
|
620 |
+
output_w = int(output_w * scale_ratio)
|
621 |
+
output_h = int(output_h * scale_ratio)
|
622 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
623 |
+
original_image = np.array(original_image)
|
624 |
+
if input_mask is not None:
|
625 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
626 |
+
input_mask = np.array(input_mask)
|
627 |
+
if original_mask is not None:
|
628 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
629 |
+
original_mask = np.array(original_mask)
|
630 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
631 |
+
else:
|
632 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
633 |
+
pass
|
634 |
+
else:
|
635 |
+
if resize_default:
|
636 |
+
short_side = min(output_w, output_h)
|
637 |
+
scale_ratio = 640 / short_side
|
638 |
+
output_w = int(output_w * scale_ratio)
|
639 |
+
output_h = int(output_h * scale_ratio)
|
640 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
641 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
642 |
+
original_image = np.array(original_image)
|
643 |
+
if input_mask is not None:
|
644 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
645 |
+
input_mask = np.array(input_mask)
|
646 |
+
if original_mask is not None:
|
647 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
648 |
+
original_mask = np.array(original_mask)
|
649 |
+
|
650 |
+
if input_mask.max() == 0:
|
651 |
+
original_mask = original_mask
|
652 |
+
else:
|
653 |
+
original_mask = input_mask
|
654 |
+
|
655 |
+
if original_mask is None:
|
656 |
+
raise gr.Error('Please generate mask first')
|
657 |
+
|
658 |
+
if original_mask.ndim == 2:
|
659 |
+
original_mask = original_mask[:,:,None]
|
660 |
+
|
661 |
+
dilation_type = np.random.choice(['square_dilation'])
|
662 |
+
random_mask = random_mask_func(original_mask, dilation_type, dilation_size).squeeze()
|
663 |
+
|
664 |
+
mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
|
665 |
+
|
666 |
+
masked_image = original_image * (1 - (random_mask[:,:,None]>0))
|
667 |
+
masked_image = masked_image.astype(original_image.dtype)
|
668 |
+
masked_image = Image.fromarray(masked_image)
|
669 |
+
|
670 |
+
return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
|
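The resize(...) helper used throughout these mask functions is also defined earlier in the file; the call sites here only rely on it returning a PIL image of exactly the requested size, roughly as in this simplified sketch (an assumption, not the exact implementation).

from PIL import Image

def resize_sketch(image: Image.Image, target_width: int, target_height: int) -> Image.Image:
    # Resample the PIL image to exactly (target_width, target_height).
    return image.resize((target_width, target_height), Image.LANCZOS)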
671 |
+
|
672 |
+
|
673 |
+
def process_erosion_mask(input_image,
|
674 |
+
original_image,
|
675 |
+
original_mask,
|
676 |
+
resize_default,
|
677 |
+
aspect_ratio_name,
|
678 |
+
dilation_size=20):
|
679 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
680 |
+
input_mask = np.asarray(alpha_mask)
|
681 |
+
output_w, output_h = aspect_ratios[aspect_ratio_name]
|
682 |
+
if output_w == "" or output_h == "":
|
683 |
+
output_h, output_w = original_image.shape[:2]
|
684 |
+
if resize_default:
|
685 |
+
short_side = min(output_w, output_h)
|
686 |
+
scale_ratio = 640 / short_side
|
687 |
+
output_w = int(output_w * scale_ratio)
|
688 |
+
output_h = int(output_h * scale_ratio)
|
689 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
690 |
+
original_image = np.array(original_image)
|
691 |
+
if input_mask is not None:
|
692 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
693 |
+
input_mask = np.array(input_mask)
|
694 |
+
if original_mask is not None:
|
695 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
696 |
+
original_mask = np.array(original_mask)
|
697 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
698 |
+
else:
|
699 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
700 |
+
pass
|
701 |
+
else:
|
702 |
+
if resize_default:
|
703 |
+
short_side = min(output_w, output_h)
|
704 |
+
scale_ratio = 640 / short_side
|
705 |
+
output_w = int(output_w * scale_ratio)
|
706 |
+
output_h = int(output_h * scale_ratio)
|
707 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
708 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
709 |
+
original_image = np.array(original_image)
|
710 |
+
if input_mask is not None:
|
711 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
712 |
+
input_mask = np.array(input_mask)
|
713 |
+
if original_mask is not None:
|
714 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
715 |
+
original_mask = np.array(original_mask)
|
716 |
+
|
717 |
+
if input_mask.max() == 0:
|
718 |
+
original_mask = original_mask
|
719 |
+
else:
|
720 |
+
original_mask = input_mask
|
721 |
+
|
722 |
+
if original_mask is None:
|
723 |
+
raise gr.Error('Please generate mask first')
|
724 |
+
|
725 |
+
if original_mask.ndim == 2:
|
726 |
+
original_mask = original_mask[:,:,None]
|
727 |
+
|
728 |
+
dilation_type = np.random.choice(['square_erosion'])
|
729 |
+
random_mask = random_mask_func(original_mask, dilation_type, dilation_size).squeeze()
|
730 |
+
|
731 |
+
mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
|
732 |
+
|
733 |
+
masked_image = original_image * (1 - (random_mask[:,:,None]>0))
|
734 |
+
masked_image = masked_image.astype(original_image.dtype)
|
735 |
+
masked_image = Image.fromarray(masked_image)
|
736 |
+
|
737 |
+
|
738 |
+
return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
|
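The 'square_dilation' / 'square_erosion' branches above go through the same random_mask_func; a minimal sketch of that morphology step with OpenCV, assuming a square kernel of size dilation_size (illustrative only):

import cv2
import numpy as np

def morph_mask_sketch(mask, dilation_type="square_dilation", dilation_size=20):
    # mask: (H, W, 1) uint8; grow or shrink the mask with a square structuring element.
    m = (mask.squeeze() > 0).astype(np.uint8)
    kernel = np.ones((dilation_size, dilation_size), np.uint8)
    if dilation_type == "square_dilation":
        m = cv2.dilate(m, kernel, iterations=1)
    else:  # "square_erosion"
        m = cv2.erode(m, kernel, iterations=1)
    return (m * 255)[:, :, None]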
739 |
+
|
740 |
+
|
741 |
+
def move_mask_left(input_image,
|
742 |
+
original_image,
|
743 |
+
original_mask,
|
744 |
+
moving_pixels,
|
745 |
+
resize_default,
|
746 |
+
aspect_ratio_name):
|
747 |
+
|
748 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
749 |
+
input_mask = np.asarray(alpha_mask)
|
750 |
+
|
751 |
+
output_w, output_h = aspect_ratios[aspect_ratio_name]
|
752 |
+
if output_w == "" or output_h == "":
|
753 |
+
output_h, output_w = original_image.shape[:2]
|
754 |
+
if resize_default:
|
755 |
+
short_side = min(output_w, output_h)
|
756 |
+
scale_ratio = 640 / short_side
|
757 |
+
output_w = int(output_w * scale_ratio)
|
758 |
+
output_h = int(output_h * scale_ratio)
|
759 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
760 |
+
original_image = np.array(original_image)
|
761 |
+
if input_mask is not None:
|
762 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
763 |
+
input_mask = np.array(input_mask)
|
764 |
+
if original_mask is not None:
|
765 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
766 |
+
original_mask = np.array(original_mask)
|
767 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
768 |
+
else:
|
769 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
770 |
+
pass
|
771 |
+
else:
|
772 |
+
if resize_default:
|
773 |
+
short_side = min(output_w, output_h)
|
774 |
+
scale_ratio = 640 / short_side
|
775 |
+
output_w = int(output_w * scale_ratio)
|
776 |
+
output_h = int(output_h * scale_ratio)
|
777 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
778 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
779 |
+
original_image = np.array(original_image)
|
780 |
+
if input_mask is not None:
|
781 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
782 |
+
input_mask = np.array(input_mask)
|
783 |
+
if original_mask is not None:
|
784 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
785 |
+
original_mask = np.array(original_mask)
|
786 |
+
|
787 |
+
if input_mask.max() == 0:
|
788 |
+
original_mask = original_mask
|
789 |
+
else:
|
790 |
+
original_mask = input_mask
|
791 |
+
|
792 |
+
if original_mask is None:
|
793 |
+
raise gr.Error('Please generate mask first')
|
794 |
+
|
795 |
+
if original_mask.ndim == 2:
|
796 |
+
original_mask = original_mask[:,:,None]
|
797 |
+
|
798 |
+
moved_mask = move_mask_func(original_mask, 'left', int(moving_pixels)).squeeze()
|
799 |
+
mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
|
800 |
+
|
801 |
+
masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
|
802 |
+
masked_image = masked_image.astype(original_image.dtype)
|
803 |
+
masked_image = Image.fromarray(masked_image)
|
804 |
+
|
805 |
+
if moved_mask.max() <= 1:
|
806 |
+
moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
|
807 |
+
original_mask = moved_mask
|
808 |
+
return [masked_image], [mask_image], original_mask.astype(np.uint8)
|
809 |
+
|
810 |
+
|
811 |
+
def move_mask_right(input_image,
|
812 |
+
original_image,
|
813 |
+
original_mask,
|
814 |
+
moving_pixels,
|
815 |
+
resize_default,
|
816 |
+
aspect_ratio_name):
|
817 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
818 |
+
input_mask = np.asarray(alpha_mask)
|
819 |
+
|
820 |
+
output_w, output_h = aspect_ratios[aspect_ratio_name]
|
821 |
+
if output_w == "" or output_h == "":
|
822 |
+
output_h, output_w = original_image.shape[:2]
|
823 |
+
if resize_default:
|
824 |
+
short_side = min(output_w, output_h)
|
825 |
+
scale_ratio = 640 / short_side
|
826 |
+
output_w = int(output_w * scale_ratio)
|
827 |
+
output_h = int(output_h * scale_ratio)
|
828 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
829 |
+
original_image = np.array(original_image)
|
830 |
+
if input_mask is not None:
|
831 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
832 |
+
input_mask = np.array(input_mask)
|
833 |
+
if original_mask is not None:
|
834 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
835 |
+
original_mask = np.array(original_mask)
|
836 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
837 |
+
else:
|
838 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
839 |
+
pass
|
840 |
+
else:
|
841 |
+
if resize_default:
|
842 |
+
short_side = min(output_w, output_h)
|
843 |
+
scale_ratio = 640 / short_side
|
844 |
+
output_w = int(output_w * scale_ratio)
|
845 |
+
output_h = int(output_h * scale_ratio)
|
846 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
847 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
848 |
+
original_image = np.array(original_image)
|
849 |
+
if input_mask is not None:
|
850 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
851 |
+
input_mask = np.array(input_mask)
|
852 |
+
if original_mask is not None:
|
853 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
854 |
+
original_mask = np.array(original_mask)
|
855 |
+
|
856 |
+
if input_mask.max() == 0:
|
857 |
+
original_mask = original_mask
|
858 |
+
else:
|
859 |
+
original_mask = input_mask
|
860 |
+
|
861 |
+
if original_mask is None:
|
862 |
+
raise gr.Error('Please generate mask first')
|
863 |
+
|
864 |
+
if original_mask.ndim == 2:
|
865 |
+
original_mask = original_mask[:,:,None]
|
866 |
+
|
867 |
+
moved_mask = move_mask_func(original_mask, 'right', int(moving_pixels)).squeeze()
|
868 |
+
|
869 |
+
mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
|
870 |
+
|
871 |
+
masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
|
872 |
+
masked_image = masked_image.astype(original_image.dtype)
|
873 |
+
masked_image = Image.fromarray(masked_image)
|
874 |
+
|
875 |
+
|
876 |
+
if moved_mask.max() <= 1:
|
877 |
+
moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
|
878 |
+
original_mask = moved_mask
|
879 |
+
|
880 |
+
return [masked_image], [mask_image], original_mask.astype(np.uint8)
|
881 |
+
|
882 |
+
|
883 |
+
def move_mask_up(input_image,
|
884 |
+
original_image,
|
885 |
+
original_mask,
|
886 |
+
moving_pixels,
|
887 |
+
resize_default,
|
888 |
+
aspect_ratio_name):
|
889 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
890 |
+
input_mask = np.asarray(alpha_mask)
|
891 |
+
|
892 |
+
output_w, output_h = aspect_ratios[aspect_ratio_name]
|
893 |
+
if output_w == "" or output_h == "":
|
894 |
+
output_h, output_w = original_image.shape[:2]
|
895 |
+
if resize_default:
|
896 |
+
short_side = min(output_w, output_h)
|
897 |
+
scale_ratio = 640 / short_side
|
898 |
+
output_w = int(output_w * scale_ratio)
|
899 |
+
output_h = int(output_h * scale_ratio)
|
900 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
901 |
+
original_image = np.array(original_image)
|
902 |
+
if input_mask is not None:
|
903 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
904 |
+
input_mask = np.array(input_mask)
|
905 |
+
if original_mask is not None:
|
906 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
907 |
+
original_mask = np.array(original_mask)
|
908 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
909 |
+
else:
|
910 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
911 |
+
pass
|
912 |
+
else:
|
913 |
+
if resize_default:
|
914 |
+
short_side = min(output_w, output_h)
|
915 |
+
scale_ratio = 640 / short_side
|
916 |
+
output_w = int(output_w * scale_ratio)
|
917 |
+
output_h = int(output_h * scale_ratio)
|
918 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
919 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
920 |
+
original_image = np.array(original_image)
|
921 |
+
if input_mask is not None:
|
922 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
923 |
+
input_mask = np.array(input_mask)
|
924 |
+
if original_mask is not None:
|
925 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
926 |
+
original_mask = np.array(original_mask)
|
927 |
+
|
928 |
+
if input_mask.max() == 0:
|
929 |
+
original_mask = original_mask
|
930 |
+
else:
|
931 |
+
original_mask = input_mask
|
932 |
+
|
933 |
+
if original_mask is None:
|
934 |
+
raise gr.Error('Please generate mask first')
|
935 |
+
|
936 |
+
if original_mask.ndim == 2:
|
937 |
+
original_mask = original_mask[:,:,None]
|
938 |
+
|
939 |
+
moved_mask = move_mask_func(original_mask, 'up', int(moving_pixels)).squeeze()
|
940 |
+
mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
|
941 |
+
|
942 |
+
masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
|
943 |
+
masked_image = masked_image.astype(original_image.dtype)
|
944 |
+
masked_image = Image.fromarray(masked_image)
|
945 |
+
|
946 |
+
if moved_mask.max() <= 1:
|
947 |
+
moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
|
948 |
+
original_mask = moved_mask
|
949 |
+
|
950 |
+
return [masked_image], [mask_image], original_mask.astype(np.uint8)
|
951 |
+
|
952 |
+
|
953 |
+
def move_mask_down(input_image,
|
954 |
+
original_image,
|
955 |
+
original_mask,
|
956 |
+
moving_pixels,
|
957 |
+
resize_default,
|
958 |
+
aspect_ratio_name):
|
959 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
960 |
+
input_mask = np.asarray(alpha_mask)
|
961 |
+
output_w, output_h = aspect_ratios[aspect_ratio_name]
|
962 |
+
if output_w == "" or output_h == "":
|
963 |
+
output_h, output_w = original_image.shape[:2]
|
964 |
+
if resize_default:
|
965 |
+
short_side = min(output_w, output_h)
|
966 |
+
scale_ratio = 640 / short_side
|
967 |
+
output_w = int(output_w * scale_ratio)
|
968 |
+
output_h = int(output_h * scale_ratio)
|
969 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
970 |
+
original_image = np.array(original_image)
|
971 |
+
if input_mask is not None:
|
972 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
973 |
+
input_mask = np.array(input_mask)
|
974 |
+
if original_mask is not None:
|
975 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
976 |
+
original_mask = np.array(original_mask)
|
977 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
978 |
+
else:
|
979 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
980 |
+
pass
|
981 |
+
else:
|
982 |
+
if resize_default:
|
983 |
+
short_side = min(output_w, output_h)
|
984 |
+
scale_ratio = 640 / short_side
|
985 |
+
output_w = int(output_w * scale_ratio)
|
986 |
+
output_h = int(output_h * scale_ratio)
|
987 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
988 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
989 |
+
original_image = np.array(original_image)
|
990 |
+
if input_mask is not None:
|
991 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
992 |
+
input_mask = np.array(input_mask)
|
993 |
+
if original_mask is not None:
|
994 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
995 |
+
original_mask = np.array(original_mask)
|
996 |
+
|
997 |
+
if input_mask.max() == 0:
|
998 |
+
original_mask = original_mask
|
999 |
+
else:
|
1000 |
+
original_mask = input_mask
|
1001 |
+
|
1002 |
+
if original_mask is None:
|
1003 |
+
raise gr.Error('Please generate mask first')
|
1004 |
+
|
1005 |
+
if original_mask.ndim == 2:
|
1006 |
+
original_mask = original_mask[:,:,None]
|
1007 |
+
|
1008 |
+
moved_mask = move_mask_func(original_mask, 'down', int(moving_pixels)).squeeze()
|
1009 |
+
mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
|
1010 |
+
|
1011 |
+
masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
|
1012 |
+
masked_image = masked_image.astype(original_image.dtype)
|
1013 |
+
masked_image = Image.fromarray(masked_image)
|
1014 |
+
|
1015 |
+
if moved_mask.max() <= 1:
|
1016 |
+
moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
|
1017 |
+
original_mask = moved_mask
|
1018 |
+
|
1019 |
+
return [masked_image], [mask_image], original_mask.astype(np.uint8)
|
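move_mask_func itself is defined earlier in the file; a minimal sketch of the shift these four handlers expect, using np.roll and zeroing the strip that would wrap around (a simplified assumption, not the exact implementation):

import numpy as np

def move_mask_sketch(mask, direction="left", pixels=4):
    # mask: (H, W, 1); shift the mask by `pixels` and clear the wrapped-around edge.
    m = mask.squeeze().copy()
    if direction == "left":
        m = np.roll(m, -pixels, axis=1); m[:, -pixels:] = 0
    elif direction == "right":
        m = np.roll(m, pixels, axis=1); m[:, :pixels] = 0
    elif direction == "up":
        m = np.roll(m, -pixels, axis=0); m[-pixels:, :] = 0
    else:  # "down"
        m = np.roll(m, pixels, axis=0); m[:pixels, :] = 0
    return m[:, :, None]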
1020 |
+
|
1021 |
+
|
1022 |
+
def invert_mask(input_image,
|
1023 |
+
original_image,
|
1024 |
+
original_mask,
|
1025 |
+
):
|
1026 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
1027 |
+
input_mask = np.asarray(alpha_mask)
|
1028 |
+
if input_mask.max() == 0:
|
1029 |
+
original_mask = 1 - (original_mask>0).astype(np.uint8)
|
1030 |
+
else:
|
1031 |
+
original_mask = 1 - (input_mask>0).astype(np.uint8)
|
1032 |
+
|
1033 |
+
if original_mask is None:
|
1034 |
+
raise gr.Error('Please generate mask first')
|
1035 |
+
|
1036 |
+
original_mask = original_mask.squeeze()
|
1037 |
+
mask_image = Image.fromarray(original_mask*255).convert("RGB")
|
1038 |
+
|
1039 |
+
if original_mask.ndim == 2:
|
1040 |
+
original_mask = original_mask[:,:,None]
|
1041 |
+
|
1042 |
+
if original_mask.max() <= 1:
|
1043 |
+
original_mask = (original_mask * 255).astype(np.uint8)
|
1044 |
+
|
1045 |
+
masked_image = original_image * (1 - (original_mask>0))
|
1046 |
+
masked_image = masked_image.astype(original_image.dtype)
|
1047 |
+
masked_image = Image.fromarray(masked_image)
|
1048 |
+
|
1049 |
+
return [masked_image], [mask_image], original_mask, True
|
1050 |
+
|
1051 |
+
|
1052 |
+
def init_img(base,
|
1053 |
+
init_type,
|
1054 |
+
prompt,
|
1055 |
+
aspect_ratio,
|
1056 |
+
example_change_times
|
1057 |
+
):
|
1058 |
+
image_pil = base["background"].convert("RGB")
|
1059 |
+
original_image = np.array(image_pil)
|
1060 |
+
if max(original_image.shape[0], original_image.shape[1]) * 1.0 / min(original_image.shape[0], original_image.shape[1])>2.0:
|
1061 |
+
raise gr.Error('image aspect ratio cannot be larger than 2.0')
|
1062 |
+
if init_type in MASK_IMAGE_PATH.keys() and example_change_times < 2:
|
1063 |
+
mask_gallery = [Image.open(MASK_IMAGE_PATH[init_type]).convert("L")]
|
1064 |
+
masked_gallery = [Image.open(MASKED_IMAGE_PATH[init_type]).convert("RGB")]
|
1065 |
+
result_gallery = [Image.open(OUTPUT_IMAGE_PATH[init_type]).convert("RGB")]
|
1066 |
+
width, height = image_pil.size
|
1067 |
+
image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
|
1068 |
+
height_new, width_new = image_processor.get_default_height_width(image_pil, height, width)
|
1069 |
+
image_pil = image_pil.resize((width_new, height_new))
|
1070 |
+
mask_gallery[0] = mask_gallery[0].resize((width_new, height_new))
|
1071 |
+
masked_gallery[0] = masked_gallery[0].resize((width_new, height_new))
|
1072 |
+
result_gallery[0] = result_gallery[0].resize((width_new, height_new))
|
1073 |
+
original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1
|
1074 |
+
return base, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, "", "", "Custom resolution", False, False, example_change_times
|
1075 |
+
else:
|
1076 |
+
if aspect_ratio not in ASPECT_RATIO_LABELS:
|
1077 |
+
aspect_ratio = "Custom resolution"
|
1078 |
+
return base, original_image, None, "", None, None, None, "", "", aspect_ratio, True, False, 0
|
1079 |
+
|
1080 |
+
|
1081 |
+
def reset_func(input_image,
|
1082 |
+
original_image,
|
1083 |
+
original_mask,
|
1084 |
+
prompt,
|
1085 |
+
target_prompt,
|
1086 |
+
):
|
1087 |
+
input_image = None
|
1088 |
+
original_image = None
|
1089 |
+
original_mask = None
|
1090 |
+
prompt = ''
|
1091 |
+
mask_gallery = []
|
1092 |
+
masked_gallery = []
|
1093 |
+
result_gallery = []
|
1094 |
+
target_prompt = ''
|
1095 |
+
if torch.cuda.is_available():
|
1096 |
+
torch.cuda.empty_cache()
|
1097 |
+
return input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, True, False
|
1098 |
+
|
1099 |
+
|
1100 |
+
def update_example(example_type,
|
1101 |
+
prompt,
|
1102 |
+
example_change_times):
|
1103 |
+
input_image = INPUT_IMAGE_PATH[example_type]
|
1104 |
+
image_pil = Image.open(input_image).convert("RGB")
|
1105 |
+
mask_gallery = [Image.open(MASK_IMAGE_PATH[example_type]).convert("L")]
|
1106 |
+
masked_gallery = [Image.open(MASKED_IMAGE_PATH[example_type]).convert("RGB")]
|
1107 |
+
result_gallery = [Image.open(OUTPUT_IMAGE_PATH[example_type]).convert("RGB")]
|
1108 |
+
width, height = image_pil.size
|
1109 |
+
image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
|
1110 |
+
height_new, width_new = image_processor.get_default_height_width(image_pil, height, width)
|
1111 |
+
image_pil = image_pil.resize((width_new, height_new))
|
1112 |
+
mask_gallery[0] = mask_gallery[0].resize((width_new, height_new))
|
1113 |
+
masked_gallery[0] = masked_gallery[0].resize((width_new, height_new))
|
1114 |
+
result_gallery[0] = result_gallery[0].resize((width_new, height_new))
|
1115 |
+
|
1116 |
+
original_image = np.array(image_pil)
|
1117 |
+
original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1
|
1118 |
+
aspect_ratio = "Custom resolution"
|
1119 |
+
example_change_times += 1
|
1120 |
+
return input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, "", False, example_change_times
|
1121 |
+
|
1122 |
+
|
1123 |
+
def generate_target_prompt(input_image,
|
1124 |
+
original_image,
|
1125 |
+
prompt):
|
1126 |
+
# load example image
|
1127 |
+
if isinstance(original_image, str):
|
1128 |
+
original_image = input_image
|
1129 |
+
|
1130 |
+
prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(
|
1131 |
+
vlm_processor,
|
1132 |
+
vlm_model,
|
1133 |
+
original_image,
|
1134 |
+
prompt,
|
1135 |
+
device)
|
1136 |
+
return prompt_after_apply_instruction
|
1137 |
+
|
1138 |
+
|
1139 |
+
|
1140 |
+
|
1141 |
+
def process_mask(input_image,
|
1142 |
+
original_image,
|
1143 |
+
prompt,
|
1144 |
+
resize_default,
|
1145 |
+
aspect_ratio_name):
|
1146 |
+
if original_image is None:
|
1147 |
+
raise gr.Error('Please upload the input image')
|
1148 |
+
if prompt is None:
|
1149 |
+
raise gr.Error("Please input your instructions, e.g., remove the xxx")
|
1150 |
+
|
1151 |
+
## load mask
|
1152 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
1153 |
+
input_mask = np.array(alpha_mask)
|
1154 |
+
|
1155 |
+
# load example image
|
1156 |
+
if isinstance(original_image, str):
|
1157 |
+
original_image = input_image["background"]
|
1158 |
+
|
1159 |
+
if input_mask.max() == 0:
|
1160 |
+
category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device)
|
1161 |
+
|
1162 |
+
object_wait_for_edit = vlm_response_object_wait_for_edit(vlm_processor,
|
1163 |
+
vlm_model,
|
1164 |
+
original_image,
|
1165 |
+
category,
|
1166 |
+
prompt,
|
1167 |
+
device)
|
1168 |
+
# original mask: h,w,1 [0, 255]
|
1169 |
+
original_mask = vlm_response_mask(
|
1170 |
+
vlm_processor,
|
1171 |
+
vlm_model,
|
1172 |
+
category,
|
1173 |
+
original_image,
|
1174 |
+
prompt,
|
1175 |
+
object_wait_for_edit,
|
1176 |
+
sam,
|
1177 |
+
sam_predictor,
|
1178 |
+
sam_automask_generator,
|
1179 |
+
groundingdino_model,
|
1180 |
+
device).astype(np.uint8)
|
1181 |
+
else:
|
1182 |
+
original_mask = input_mask.astype(np.uint8)
|
1183 |
+
category = None
|
1184 |
+
|
1185 |
+
## resize mask if needed
|
1186 |
+
output_w, output_h = aspect_ratios[aspect_ratio_name]
|
1187 |
+
if output_w == "" or output_h == "":
|
1188 |
+
output_h, output_w = original_image.shape[:2]
|
1189 |
+
if resize_default:
|
1190 |
+
short_side = min(output_w, output_h)
|
1191 |
+
scale_ratio = 640 / short_side
|
1192 |
+
output_w = int(output_w * scale_ratio)
|
1193 |
+
output_h = int(output_h * scale_ratio)
|
1194 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
1195 |
+
original_image = np.array(original_image)
|
1196 |
+
if input_mask is not None:
|
1197 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
1198 |
+
input_mask = np.array(input_mask)
|
1199 |
+
if original_mask is not None:
|
1200 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
1201 |
+
original_mask = np.array(original_mask)
|
1202 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1203 |
+
else:
|
1204 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1205 |
+
pass
|
1206 |
+
else:
|
1207 |
+
if resize_default:
|
1208 |
+
short_side = min(output_w, output_h)
|
1209 |
+
scale_ratio = 640 / short_side
|
1210 |
+
output_w = int(output_w * scale_ratio)
|
1211 |
+
output_h = int(output_h * scale_ratio)
|
1212 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1213 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
1214 |
+
original_image = np.array(original_image)
|
1215 |
+
if input_mask is not None:
|
1216 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
1217 |
+
input_mask = np.array(input_mask)
|
1218 |
+
if original_mask is not None:
|
1219 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
1220 |
+
original_mask = np.array(original_mask)
|
1221 |
+
|
1222 |
+
|
1223 |
+
if original_mask.ndim == 2:
|
1224 |
+
original_mask = original_mask[:,:,None]
|
1225 |
+
|
1226 |
+
mask_image = Image.fromarray(original_mask.squeeze().astype(np.uint8)).convert("RGB")
|
1227 |
+
|
1228 |
+
masked_image = original_image * (1 - (original_mask>0))
|
1229 |
+
masked_image = masked_image.astype(np.uint8)
|
1230 |
+
masked_image = Image.fromarray(masked_image)
|
1231 |
+
|
1232 |
+
return [masked_image], [mask_image], original_mask.astype(np.uint8), category
|
1233 |
+
|
1234 |
+
|
1235 |
+
|
1236 |
+
def process(input_image,
|
1237 |
+
original_image,
|
1238 |
+
original_mask,
|
1239 |
+
prompt,
|
1240 |
+
negative_prompt,
|
1241 |
+
control_strength,
|
1242 |
+
seed,
|
1243 |
+
randomize_seed,
|
1244 |
+
guidance_scale,
|
1245 |
+
num_inference_steps,
|
1246 |
+
num_samples,
|
1247 |
+
blending,
|
1248 |
+
category,
|
1249 |
+
target_prompt,
|
1250 |
+
resize_default,
|
1251 |
+
aspect_ratio_name,
|
1252 |
+
invert_mask_state):
|
1253 |
+
if original_image is None:
|
1254 |
+
if input_image is None:
|
1255 |
+
raise gr.Error('Please upload the input image')
|
1256 |
+
else:
|
1257 |
+
image_pil = input_image["background"].convert("RGB")
|
1258 |
+
original_image = np.array(image_pil)
|
1259 |
+
if prompt is None or prompt == "":
|
1260 |
+
if target_prompt is None or target_prompt == "":
|
1261 |
+
raise gr.Error("Please input your instructions, e.g., remove the xxx")
|
1262 |
+
|
1263 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
1264 |
+
input_mask = np.asarray(alpha_mask)
|
1265 |
+
output_w, output_h = aspect_ratios[aspect_ratio_name]
|
1266 |
+
if output_w == "" or output_h == "":
|
1267 |
+
output_h, output_w = original_image.shape[:2]
|
1268 |
+
|
1269 |
+
if resize_default:
|
1270 |
+
short_side = min(output_w, output_h)
|
1271 |
+
scale_ratio = 640 / short_side
|
1272 |
+
output_w = int(output_w * scale_ratio)
|
1273 |
+
output_h = int(output_h * scale_ratio)
|
1274 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
1275 |
+
original_image = np.array(original_image)
|
1276 |
+
if input_mask is not None:
|
1277 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
1278 |
+
input_mask = np.array(input_mask)
|
1279 |
+
if original_mask is not None:
|
1280 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
1281 |
+
original_mask = np.array(original_mask)
|
1282 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1283 |
+
else:
|
1284 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1285 |
+
pass
|
1286 |
+
else:
|
1287 |
+
if resize_default:
|
1288 |
+
short_side = min(output_w, output_h)
|
1289 |
+
scale_ratio = 640 / short_side
|
1290 |
+
output_w = int(output_w * scale_ratio)
|
1291 |
+
output_h = int(output_h * scale_ratio)
|
1292 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1293 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
1294 |
+
original_image = np.array(original_image)
|
1295 |
+
if input_mask is not None:
|
1296 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
1297 |
+
input_mask = np.array(input_mask)
|
1298 |
+
if original_mask is not None:
|
1299 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
1300 |
+
original_mask = np.array(original_mask)
|
1301 |
+
|
1302 |
+
if invert_mask_state:
|
1303 |
+
original_mask = original_mask
|
1304 |
+
else:
|
1305 |
+
if input_mask.max() == 0:
|
1306 |
+
original_mask = original_mask
|
1307 |
+
else:
|
1308 |
+
original_mask = input_mask
|
1309 |
+
|
1310 |
+
|
1311 |
+
# inpainting directly if target_prompt is not None
|
1312 |
+
if category is not None:
|
1313 |
+
pass
|
1314 |
+
elif target_prompt is not None and len(target_prompt) >= 1 and original_mask is not None:
|
1315 |
+
pass
|
1316 |
+
else:
|
1317 |
+
try:
|
1318 |
+
category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device)
|
1319 |
+
except Exception as e:
|
1320 |
+
raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
|
1321 |
+
|
1322 |
+
|
1323 |
+
if original_mask is not None:
|
1324 |
+
original_mask = np.clip(original_mask, 0, 255).astype(np.uint8)
|
1325 |
+
else:
|
1326 |
+
try:
|
1327 |
+
object_wait_for_edit = vlm_response_object_wait_for_edit(
|
1328 |
+
vlm_processor,
|
1329 |
+
vlm_model,
|
1330 |
+
original_image,
|
1331 |
+
category,
|
1332 |
+
prompt,
|
1333 |
+
device)
|
1334 |
+
|
1335 |
+
original_mask = vlm_response_mask(vlm_processor,
|
1336 |
+
vlm_model,
|
1337 |
+
category,
|
1338 |
+
original_image,
|
1339 |
+
prompt,
|
1340 |
+
object_wait_for_edit,
|
1341 |
+
sam,
|
1342 |
+
sam_predictor,
|
1343 |
+
sam_automask_generator,
|
1344 |
+
groundingdino_model,
|
1345 |
+
device).astype(np.uint8)
|
1346 |
+
except Exception as e:
|
1347 |
+
raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
|
1348 |
+
|
1349 |
+
if original_mask.ndim == 2:
|
1350 |
+
original_mask = original_mask[:,:,None]
|
1351 |
+
|
1352 |
+
|
1353 |
+
if target_prompt is not None and len(target_prompt) >= 1:
|
1354 |
+
prompt_after_apply_instruction = target_prompt
|
1355 |
+
|
1356 |
+
else:
|
1357 |
+
try:
|
1358 |
+
prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(
|
1359 |
+
vlm_processor,
|
1360 |
+
vlm_model,
|
1361 |
+
original_image,
|
1362 |
+
prompt,
|
1363 |
+
device)
|
1364 |
+
except Exception as e:
|
1365 |
+
raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
|
1366 |
+
|
1367 |
+
generator = torch.Generator(device).manual_seed(random.randint(0, 2147483647) if randomize_seed else seed)
|
1368 |
+
|
1369 |
+
|
1370 |
+
with torch.autocast(device):
|
1371 |
+
image, mask_image, mask_np, init_image_np = BrushEdit_Pipeline(pipe,
|
1372 |
+
prompt_after_apply_instruction,
|
1373 |
+
original_mask,
|
1374 |
+
original_image,
|
1375 |
+
generator,
|
1376 |
+
num_inference_steps,
|
1377 |
+
guidance_scale,
|
1378 |
+
control_strength,
|
1379 |
+
negative_prompt,
|
1380 |
+
num_samples,
|
1381 |
+
blending)
|
1382 |
+
original_image = np.array(init_image_np)
|
1383 |
+
masked_image = original_image * (1 - (mask_np>0))
|
1384 |
+
masked_image = masked_image.astype(np.uint8)
|
1385 |
+
masked_image = Image.fromarray(masked_image)
|
1386 |
+
# Save the images (optional)
|
1387 |
+
# import uuid
|
1388 |
+
# uuid = str(uuid.uuid4())
|
1389 |
+
# image[0].save(f"outputs/image_edit_{uuid}_0.png")
|
1390 |
+
# image[1].save(f"outputs/image_edit_{uuid}_1.png")
|
1391 |
+
# image[2].save(f"outputs/image_edit_{uuid}_2.png")
|
1392 |
+
# image[3].save(f"outputs/image_edit_{uuid}_3.png")
|
1393 |
+
# mask_image.save(f"outputs/mask_{uuid}.png")
|
1394 |
+
# masked_image.save(f"outputs/masked_image_{uuid}.png")
|
1395 |
+
# gr.Info(f"Target Prompt: {prompt_after_apply_instruction}", duration=20)
|
1396 |
+
return image, [mask_image], [masked_image], prompt, '', False
|
1397 |
+
|
1398 |
+
|
1399 |
+
# Newly added event handler functions
|
1400 |
+
def generate_blip_description(input_image):
|
1401 |
+
if input_image is None:
|
1402 |
+
return "", "Input image cannot be None"
|
1403 |
+
try:
|
1404 |
+
image_pil = input_image["background"].convert("RGB")
|
1405 |
+
except KeyError:
|
1406 |
+
return "", "Input image missing 'background' key"
|
1407 |
+
except AttributeError as e:
|
1408 |
+
return "", f"Invalid image object: {str(e)}"
|
1409 |
+
try:
|
1410 |
+
description = generate_caption(blip_processor, blip_model, image_pil, device)
|
1411 |
+
return description, description # update both the state and the display component
|
1412 |
+
except Exception as e:
|
1413 |
+
return "", f"Caption generation failed: {str(e)}"
|
1414 |
+
|
1415 |
+
from app.utils.utils import generate_caption
|
1416 |
+
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
|
1417 |
+
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
1418 |
+
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large", torch_dtype=torch.float16).to(device)
|
1419 |
+
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32",torch_dtype=torch.float16).to(device)
|
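generate_caption is imported from app.utils.utils and is not shown in this diff; the sketch below shows an assumed BLIP captioning call compatible with the processor and model loaded above (an assumption about its behaviour, not the actual helper):

import torch

def generate_caption_sketch(processor, model, image_pil, device):
    # Encode the image, sample a caption with BLIP, and decode it back to text.
    inputs = processor(images=image_pil, return_tensors="pt").to(device, torch.float16)
    with torch.no_grad():
        out = model.generate(**inputs, max_new_tokens=50)
    return processor.decode(out[0], skip_special_tokens=True)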
1420 |
+
|
1421 |
+
|
1422 |
+
def submit_GPT4o_KEY(GPT4o_KEY):
|
1423 |
+
global vlm_model, vlm_processor
|
1424 |
+
if vlm_model is not None:
|
1425 |
+
del vlm_model
|
1426 |
+
torch.cuda.empty_cache()
|
1427 |
+
try:
|
1428 |
+
vlm_model = OpenAI(api_key=GPT4o_KEY, base_url="https://api.deepseek.com")
|
1429 |
+
vlm_processor = ""
|
1430 |
+
response = vlm_model.chat.completions.create(
|
1431 |
+
model="deepseek-chat",
|
1432 |
+
messages=[
|
1433 |
+
{"role": "system", "content": "You are a helpful assistant."},
|
1434 |
+
{"role": "user", "content": "Hello."}
|
1435 |
+
]
|
1436 |
+
)
|
1437 |
+
response_str = response.choices[0].message.content
|
1438 |
+
|
1439 |
+
return "Success. " + response_str, "GPT4-o (Highly Recommended)"
|
1440 |
+
except Exception as e:
|
1441 |
+
return "Invalid GPT4o API Key", "GPT4-o (Highly Recommended)"
|
1442 |
+
|
1443 |
+
|
1444 |
+
def verify_deepseek_api():
|
1445 |
+
try:
|
1446 |
+
response = llm_model.chat.completions.create(
|
1447 |
+
model="deepseek-chat",
|
1448 |
+
messages=[
|
1449 |
+
{"role": "system", "content": "You are a helpful assistant."},
|
1450 |
+
{"role": "user", "content": "Hello."}
|
1451 |
+
]
|
1452 |
+
)
|
1453 |
+
response_str = response.choices[0].message.content
|
1454 |
+
|
1455 |
+
return True, "Success. " + response_str
|
1456 |
+
|
1457 |
+
except Exception as e:
|
1458 |
+
return False, "Invalid DeepSeek API Key"
|
1459 |
+
|
1460 |
+
|
1461 |
+
def llm_enhanced_prompt_after_apply_instruction(image_caption, editing_prompt):
|
1462 |
+
try:
|
1463 |
+
messages = create_apply_editing_messages_deepseek(image_caption, editing_prompt)
|
1464 |
+
response = llm_model.chat.completions.create(
|
1465 |
+
model="deepseek-chat",
|
1466 |
+
messages=messages
|
1467 |
+
)
|
1468 |
+
response_str = response.choices[0].message.content
|
1469 |
+
return response_str
|
1470 |
+
except Exception as e:
|
1471 |
+
raise gr.Error(f"整合指令时遇到错误: {str(e)},请检查控制台日志获取详细信息")
|
1472 |
+
|
1473 |
+
|
1474 |
+
def llm_decomposed_prompt_after_apply_instruction(integrated_query):
|
1475 |
+
try:
|
1476 |
+
messages = create_decomposed_query_messages_deepseek(integrated_query)
|
1477 |
+
response = llm_model.chat.completions.create(
|
1478 |
+
model="deepseek-chat",
|
1479 |
+
messages=messages
|
1480 |
+
)
|
1481 |
+
response_str = response.choices[0].message.content
|
1482 |
+
return response_str
|
1483 |
+
except Exception as e:
|
1484 |
+
raise gr.Error(f"分解指令时遇到错误: {str(e)},请检查控制台日志获取详细信息")
|
1485 |
+
|
1486 |
+
|
1487 |
+
def enhance_description(blip_description, prompt):
|
1488 |
+
try:
|
1489 |
+
if not prompt or not blip_description:
|
1490 |
+
print("Empty prompt or blip_description detected")
|
1491 |
+
return "", ""
|
1492 |
+
|
1493 |
+
print(f"Enhancing with prompt: {prompt}")
|
1494 |
+
enhanced_description = llm_enhanced_prompt_after_apply_instruction(blip_description, prompt)
|
1495 |
+
return enhanced_description, enhanced_description
|
1496 |
+
|
1497 |
+
except Exception as e:
|
1498 |
+
print(f"Enhancement failed: {str(e)}")
|
1499 |
+
return "Error occurred", "Error occurred"
|
1500 |
+
|
1501 |
+
|
1502 |
+
def decompose_description(enhanced_description):
|
1503 |
+
try:
|
1504 |
+
if not enhanced_description:
|
1505 |
+
print("Empty enhanced_description detected")
|
1506 |
+
return "", ""
|
1507 |
+
|
1508 |
+
print(f"Decomposing the enhanced description: {enhanced_description}")
|
1509 |
+
decomposed_description = llm_decomposed_prompt_after_apply_instruction(enhanced_description)
|
1510 |
+
return decomposed_description, decomposed_description
|
1511 |
+
|
1512 |
+
except Exception as e:
|
1513 |
+
print(f"Decomposition failed: {str(e)}")
|
1514 |
+
return "Error occurred", "Error occurred"
|
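Taken together, the two helpers above form the caption -> integrated query -> decomposed query chain that the UI buttons trigger. A minimal usage sketch, assuming the DeepSeek client is configured and the helpers above are in scope (caption and instruction values are illustrative):

# Hypothetical end-to-end text pipeline for one image.
blip_caption = "a brown dog sitting on a wooden bench"
instruction = "replace the dog with a white cat"

integrated, _ = enhance_description(blip_caption, instruction)   # DeepSeek merges caption + instruction
decomposed, _ = decompose_description(integrated)                # DeepSeek splits it into sub-queries
print(integrated)
print(decomposed)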
1515 |
+
|
1516 |
+
|
1517 |
+
@torch.no_grad()
|
1518 |
+
def mix_and_search(enhanced_text: str, gallery_images: list):
|
1519 |
+
# Get the most recently generated image tuple
|
1520 |
+
latest_item = gallery_images[-1] if gallery_images else None
|
1521 |
+
|
1522 |
+
# 初始化特征列表
|
1523 |
+
features = []
|
1524 |
+
|
1525 |
+
# Image feature extraction
|
1526 |
+
if latest_item and isinstance(latest_item, tuple):
|
1527 |
+
try:
|
1528 |
+
image_path = latest_item[0]
|
1529 |
+
pil_image = Image.open(image_path).convert("RGB")
|
1530 |
+
|
1531 |
+
# Process the image with CLIPProcessor
|
1532 |
+
image_inputs = clip_processor(
|
1533 |
+
images=pil_image,
|
1534 |
+
return_tensors="pt"
|
1535 |
+
).to(device)
|
1536 |
+
|
1537 |
+
image_features = clip_model.get_image_features(**image_inputs)
|
1538 |
+
features.append(F.normalize(image_features, dim=-1))
|
1539 |
+
except Exception as e:
|
1540 |
+
print(f"图像处理失败: {str(e)}")
|
1541 |
+
|
1542 |
+
# Text feature extraction
|
1543 |
+
if enhanced_text.strip():
|
1544 |
+
text_inputs = clip_processor(
|
1545 |
+
text=enhanced_text,
|
1546 |
+
return_tensors="pt",
|
1547 |
+
padding=True,
|
1548 |
+
truncation=True
|
1549 |
+
).to(device)
|
1550 |
+
|
1551 |
+
text_features = clip_model.get_text_features(**text_inputs)
|
1552 |
+
features.append(F.normalize(text_features, dim=-1))
|
1553 |
+
|
1554 |
+
if not features:
|
1555 |
+
return "## 错误:请先完成图像编辑并生成描述", []
|
1556 |
+
|
1557 |
+
# Feature fusion and retrieval
|
1558 |
+
mixed = sum(features) / len(features)
|
1559 |
+
mixed = F.normalize(mixed, dim=-1)
|
1560 |
+
|
1561 |
+
# Load the Faiss index and the image-path mapping
|
1562 |
+
index_path = "/home/zt/data/open-images/train/knn.index"
|
1563 |
+
input_data_dir = Path("/home/zt/data/open-images/train/embedding_folder/metadata")
|
1564 |
+
base_image_dir = Path("/home/zt/data/open-images/train/")
|
1565 |
+
|
1566 |
+
# Sort parquet files by the number in the filename and read them directly
|
1567 |
+
parquet_files = sorted(
|
1568 |
+
input_data_dir.glob('*.parquet'),
|
1569 |
+
key=lambda x: int(x.stem.split("_")[-1])
|
1570 |
+
)
|
1571 |
+
|
1572 |
+
# Concatenate all parquet data
|
1573 |
+
dfs = [pd.read_parquet(file) for file in parquet_files]  # read each file inline
|
1574 |
+
df = pd.concat(dfs, ignore_index=True)
|
1575 |
+
image_paths = df["image_path"].tolist()
|
1576 |
+
|
1577 |
+
# Read the Faiss index
|
1578 |
+
index = faiss.read_index(index_path)
|
1579 |
+
assert mixed.shape[1] == index.d, "Feature dimension mismatch"
|
1580 |
+
|
1581 |
+
# Run the nearest-neighbour search
|
1582 |
+
mixed = mixed.cpu().detach().numpy().astype('float32')
|
1583 |
+
distances, indices = index.search(mixed, 5)
|
1584 |
+
|
1585 |
+
# Fetch and validate the retrieved image paths
|
1586 |
+
retrieved_images = []
|
1587 |
+
for idx in indices[0]:
|
1588 |
+
if 0 <= idx < len(image_paths):
|
1589 |
+
img_path = base_image_dir / image_paths[idx]
|
1590 |
+
try:
|
1591 |
+
if img_path.exists():
|
1592 |
+
retrieved_images.append(Image.open(img_path).convert("RGB"))
|
1593 |
+
else:
|
1594 |
+
print(f"警告:文件缺失 {img_path}")
|
1595 |
+
except Exception as e:
|
1596 |
+
print(f"图片加载失败: {str(e)}")
|
1597 |
+
|
1598 |
+
return "## 检索到以下相似图片:", retrieved_images if retrieved_images else ("## 未找到匹配的图片", [])
|
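mix_and_search assumes a prebuilt knn.index over L2-normalized CLIP image embeddings plus parquet metadata holding image_path in the same order as the index ids. The sketch below shows how such an index might be built offline (an assumption about that step; it is not code from this repo, and the embedding values are random placeholders):

import faiss
import numpy as np
import pandas as pd

# embeddings: (N, 512) float32, L2-normalized CLIP image features (random here for illustration).
embeddings = np.random.rand(1000, 512).astype("float32")
embeddings /= np.linalg.norm(embeddings, axis=1, keepdims=True)

index = faiss.IndexFlatIP(embeddings.shape[1])   # inner product == cosine similarity on normalized vectors
index.add(embeddings)
faiss.write_index(index, "knn.index")

# Metadata parquet with one row per embedding, matching the index ids.
pd.DataFrame({"image_path": [f"images/{i}.jpg" for i in range(len(embeddings))]}).to_parquet("metadata_0.parquet")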
1599 |
+
|
1600 |
+
|
1601 |
+
block = gr.Blocks(
|
1602 |
+
theme=gr.themes.Soft(
|
1603 |
+
radius_size=gr.themes.sizes.radius_none,
|
1604 |
+
text_size=gr.themes.sizes.text_md
|
1605 |
+
)
|
1606 |
+
)
|
1607 |
+
with block as demo:
|
1608 |
+
with gr.Row():
|
1609 |
+
with gr.Column():
|
1610 |
+
gr.HTML(head)
|
1611 |
+
|
1612 |
+
gr.Markdown(descriptions)
|
1613 |
+
|
1614 |
+
with gr.Accordion(label="🧭 Instructions:", open=True, elem_id="accordion"):
|
1615 |
+
with gr.Row(equal_height=True):
|
1616 |
+
gr.Markdown(instructions)
|
1617 |
+
|
1618 |
+
original_image = gr.State(value=None)
|
1619 |
+
original_mask = gr.State(value=None)
|
1620 |
+
category = gr.State(value=None)
|
1621 |
+
status = gr.State(value=None)
|
1622 |
+
invert_mask_state = gr.State(value=False)
|
1623 |
+
example_change_times = gr.State(value=0)
|
1624 |
+
deepseek_verified = gr.State(value=False)
|
1625 |
+
blip_description = gr.State(value="")
|
1626 |
+
enhanced_description = gr.State(value="")
|
1627 |
+
decomposed_description = gr.State(value="")
|
1628 |
+
|
1629 |
+
with gr.Row():
|
1630 |
+
with gr.Column():
|
1631 |
+
with gr.Row():
|
1632 |
+
input_image = gr.ImageEditor(
|
1633 |
+
label="参考图像",
|
1634 |
+
type="pil",
|
1635 |
+
brush=gr.Brush(colors=["#FFFFFF"], default_size = 30, color_mode="fixed"),
|
1636 |
+
layers = False,
|
1637 |
+
interactive=True,
|
1638 |
+
# height=1024,
|
1639 |
+
height=512,
|
1640 |
+
sources=["upload"],
|
1641 |
+
placeholder="🫧 点击此处或下面的图标上传图像 🫧",
|
1642 |
+
)
|
1643 |
+
|
1644 |
+
prompt = gr.Textbox(label="修改指令", placeholder="😜 在此处输入你对参考图像的修改预期 😜", value="",lines=1)
|
1645 |
+
run_button = gr.Button("💫 图像编辑")
|
1646 |
+
|
1647 |
+
vlm_model_dropdown = gr.Dropdown(label="VLM 模型", choices=VLM_MODEL_NAMES, value=DEFAULT_VLM_MODEL_NAME, interactive=True)
|
1648 |
+
with gr.Group():
|
1649 |
+
with gr.Row():
|
1650 |
+
# GPT4o_KEY = gr.Textbox(label="GPT4o API Key", placeholder="Please input your GPT4o API Key when use GPT4o VLM (highly recommended).", value="", lines=1)
|
1651 |
+
GPT4o_KEY = gr.Textbox(label="密钥输入", value="sk-d145b963a92649a88843caeb741e8bbc", lines=1)
|
1652 |
+
GPT4o_KEY_submit = gr.Button("🙈 验证")
|
1653 |
+
|
1654 |
+
|
1655 |
+
aspect_ratio = gr.Dropdown(label="Output aspect ratio", choices=ASPECT_RATIO_LABELS, value=DEFAULT_ASPECT_RATIO)
|
1656 |
+
resize_default = gr.Checkbox(label="Short edge resize to 640px", value=True)
|
1657 |
+
|
1658 |
+
with gr.Row():
|
1659 |
+
mask_button = gr.Button("💎 掩膜生成")
|
1660 |
+
random_mask_button = gr.Button("Square/Circle Mask ")
|
1661 |
+
|
1662 |
+
# Added after the decompose button
|
1663 |
+
with gr.Group():
|
1664 |
+
with gr.Row():
|
1665 |
+
retrieve_button = gr.Button("🔍 开始检索")
|
1666 |
+
with gr.Row():
|
1667 |
+
retrieve_output = gr.Markdown(elem_id="accordion")
|
1668 |
+
with gr.Row():
|
1669 |
+
retrieve_gallery = gr.Gallery(label="🎊 检索结果", show_label=True, elem_id="gallery", preview=True, height=400) # newly added Gallery component
|
1670 |
+
|
1671 |
+
with gr.Row():
|
1672 |
+
generate_target_prompt_button = gr.Button("Generate Target Prompt")
|
1673 |
+
|
1674 |
+
target_prompt = gr.Text(
|
1675 |
+
label="Input Target Prompt",
|
1676 |
+
max_lines=5,
|
1677 |
+
placeholder="VLM-generated target prompt, you can first generate if and then modify it (optional)",
|
1678 |
+
value='',
|
1679 |
+
lines=2
|
1680 |
+
)
|
1681 |
+
|
1682 |
+
with gr.Accordion("Advanced Options", open=False, elem_id="accordion1"):
|
1683 |
+
base_model_dropdown = gr.Dropdown(label="Base model", choices=BASE_MODELS, value=DEFAULT_BASE_MODEL, interactive=True)
|
1684 |
+
negative_prompt = gr.Text(
|
1685 |
+
label="Negative Prompt",
|
1686 |
+
max_lines=5,
|
1687 |
+
placeholder="Please input your negative prompt",
|
1688 |
+
value='ugly, low quality',lines=1
|
1689 |
+
)
|
1690 |
+
|
1691 |
+
control_strength = gr.Slider(
|
1692 |
+
label="Control Strength: ", show_label=True, minimum=0, maximum=1.1, value=1, step=0.01
|
1693 |
+
)
|
1694 |
+
with gr.Group():
|
1695 |
+
seed = gr.Slider(
|
1696 |
+
label="Seed: ", minimum=0, maximum=2147483647, step=1, value=648464818
|
1697 |
+
)
|
1698 |
+
randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
|
1699 |
+
|
1700 |
+
blending = gr.Checkbox(label="Blending mode", value=True)
|
1701 |
+
|
1702 |
+
|
1703 |
+
num_samples = gr.Slider(
|
1704 |
+
label="Num samples", minimum=0, maximum=4, step=1, value=4
|
1705 |
+
)
|
1706 |
+
|
1707 |
+
with gr.Group():
|
1708 |
+
with gr.Row():
|
1709 |
+
guidance_scale = gr.Slider(
|
1710 |
+
label="Guidance scale",
|
1711 |
+
minimum=1,
|
1712 |
+
maximum=12,
|
1713 |
+
step=0.1,
|
1714 |
+
value=7.5,
|
1715 |
+
)
|
1716 |
+
num_inference_steps = gr.Slider(
|
1717 |
+
label="Number of inference steps",
|
1718 |
+
minimum=1,
|
1719 |
+
maximum=50,
|
1720 |
+
step=1,
|
1721 |
+
value=50,
|
1722 |
+
)
|
1723 |
+
|
1724 |
+
|
1725 |
+
with gr.Group(visible=True):
|
1726 |
+
# BLIP-generated description
|
1727 |
+
blip_output = gr.Textbox(label="原图描述", placeholder="💭 BLIP生成的图像基础描述 💭", interactive=True, lines=1)
|
1728 |
+
# DeepSeek API key verification
|
1729 |
+
with gr.Row():
|
1730 |
+
deepseek_key = gr.Textbox(label="密钥输入", value="sk-d145b963a92649a88843caeb741e8bbc", lines=1)
|
1731 |
+
verify_deepseek = gr.Button("🙈 验证")
|
1732 |
+
# Integrated description area
|
1733 |
+
with gr.Row():
|
1734 |
+
enhanced_output = gr.Textbox(label="描述整合", placeholder="💭 DeepSeek生成的增强描述 💭", interactive=True, lines=3)
|
1735 |
+
enhance_button = gr.Button("✨ 整合")
|
1736 |
+
# Decomposed description area
|
1737 |
+
with gr.Row():
|
1738 |
+
decomposed_output = gr.Textbox(label="描述分解", placeholder="💭 DeepSeek生成的分解描述 💭", interactive=True, lines=3)
|
1739 |
+
decompose_button = gr.Button("🔧 分解")
|
1740 |
+
with gr.Row():
|
1741 |
+
with gr.Tab(elem_classes="feedback", label="Masked Image"):
|
1742 |
+
masked_gallery = gr.Gallery(label='Masked Image', show_label=True, elem_id="gallery", preview=True, height=360)
|
1743 |
+
with gr.Tab(elem_classes="feedback", label="Mask"):
|
1744 |
+
mask_gallery = gr.Gallery(label='Mask', show_label=True, elem_id="gallery", preview=True, height=360)
|
1745 |
+
|
1746 |
+
invert_mask_button = gr.Button("Invert Mask")
|
1747 |
+
dilation_size = gr.Slider(
|
1748 |
+
label="Dilation size: ", minimum=0, maximum=50, step=1, value=20
|
1749 |
+
)
|
1750 |
+
with gr.Row():
|
1751 |
+
dilation_mask_button = gr.Button("Dilation Generated Mask")
|
1752 |
+
erosion_mask_button = gr.Button("Erosion Generated Mask")
|
1753 |
+
|
1754 |
+
moving_pixels = gr.Slider(
|
1755 |
+
label="Moving pixels:", show_label=True, minimum=0, maximum=50, value=4, step=1
|
1756 |
+
)
|
1757 |
+
with gr.Row():
|
1758 |
+
move_left_button = gr.Button("Move Left")
|
1759 |
+
move_right_button = gr.Button("Move Right")
|
1760 |
+
with gr.Row():
|
1761 |
+
move_up_button = gr.Button("Move Up")
|
1762 |
+
move_down_button = gr.Button("Move Down")
|
1763 |
+
|
1764 |
+
with gr.Tab(elem_classes="feedback", label="Output"):
|
1765 |
+
result_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", preview=True, height=400)
|
1766 |
+
|
1767 |
+
target_prompt_output = gr.Text(label="Output Target Prompt", value="", lines=1, interactive=False)
|
1768 |
+
|
1769 |
+
reset_button = gr.Button("Reset")
|
1770 |
+
|
1771 |
+
init_type = gr.Textbox(label="Init Name", value="", visible=False)
|
1772 |
+
example_type = gr.Textbox(label="Example Name", value="", visible=False)
|
1773 |
+
|
1774 |
+
|
1775 |
+
|
1776 |
+
with gr.Row():
|
1777 |
+
example = gr.Examples(
|
1778 |
+
label="Quick Example",
|
1779 |
+
examples=EXAMPLES,
|
1780 |
+
inputs=[input_image, prompt, seed, init_type, example_type, blending, resize_default, vlm_model_dropdown],
|
1781 |
+
examples_per_page=10,
|
1782 |
+
cache_examples=False,
|
1783 |
+
)
|
1784 |
+
|
1785 |
+
|
1786 |
+
with gr.Accordion(label="🎬 Feature Details:", open=True, elem_id="accordion"):
|
1787 |
+
with gr.Row(equal_height=True):
|
1788 |
+
gr.Markdown(tips)
|
1789 |
+
|
1790 |
+
with gr.Row():
|
1791 |
+
gr.Markdown(citation)
|
1792 |
+
|
1793 |
+
## gr.Examples cannot be used to update gr.Gallery, so we need the following two functions to update it.
|
1794 |
+
## We also need to resolve the conflict between the upload handler and the change-example handler.
|
1795 |
+
input_image.upload(
|
1796 |
+
init_img,
|
1797 |
+
[input_image, init_type, prompt, aspect_ratio, example_change_times],
|
1798 |
+
[input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, init_type, aspect_ratio, resize_default, invert_mask_state, example_change_times]
|
1799 |
+
|
1800 |
+
)
|
1801 |
+
example_type.change(fn=update_example, inputs=[example_type, prompt, example_change_times], outputs=[input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, target_prompt, invert_mask_state, example_change_times])
|
1802 |
+
|
1803 |
+
|
1804 |
+
## vlm and base model dropdown
|
1805 |
+
vlm_model_dropdown.change(fn=update_vlm_model, inputs=[vlm_model_dropdown], outputs=[status])
|
1806 |
+
base_model_dropdown.change(fn=update_base_model, inputs=[base_model_dropdown], outputs=[status])
|
1807 |
+
|
1808 |
+
|
1809 |
+
GPT4o_KEY_submit.click(fn=submit_GPT4o_KEY, inputs=[GPT4o_KEY], outputs=[GPT4o_KEY, vlm_model_dropdown])
|
1810 |
+
invert_mask_button.click(fn=invert_mask, inputs=[input_image, original_image, original_mask], outputs=[masked_gallery, mask_gallery, original_mask, invert_mask_state])
|
1811 |
+
|
1812 |
+
|
1813 |
+
ips=[input_image,
|
1814 |
+
original_image,
|
1815 |
+
original_mask,
|
1816 |
+
prompt,
|
1817 |
+
negative_prompt,
|
1818 |
+
control_strength,
|
1819 |
+
seed,
|
1820 |
+
randomize_seed,
|
1821 |
+
guidance_scale,
|
1822 |
+
num_inference_steps,
|
1823 |
+
num_samples,
|
1824 |
+
blending,
|
1825 |
+
category,
|
1826 |
+
target_prompt,
|
1827 |
+
resize_default,
|
1828 |
+
aspect_ratio,
|
1829 |
+
invert_mask_state]
|
1830 |
+
|
1831 |
+
## run brushedit
|
1832 |
+
run_button.click(fn=process, inputs=ips, outputs=[result_gallery, mask_gallery, masked_gallery, prompt, target_prompt, invert_mask_state])
|
1833 |
+
|
1834 |
+
|
1835 |
+
## mask func
|
1836 |
+
mask_button.click(fn=process_mask, inputs=[input_image, original_image, prompt, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask, category])
|
1837 |
+
random_mask_button.click(fn=process_random_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
|
1838 |
+
dilation_mask_button.click(fn=process_dilation_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[ masked_gallery, mask_gallery, original_mask])
|
1839 |
+
erosion_mask_button.click(fn=process_erosion_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[ masked_gallery, mask_gallery, original_mask])
|
1840 |
+
|
1841 |
+
## reset func
|
1842 |
+
reset_button.click(fn=reset_func, inputs=[input_image, original_image, original_mask, prompt, target_prompt], outputs=[input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, resize_default, invert_mask_state])
|
1843 |
+
|
1844 |
+
|
1845 |
+
|

# Bind event handlers
input_image.upload(fn=generate_blip_description, inputs=[input_image], outputs=[blip_description, blip_output])
verify_deepseek.click(fn=verify_deepseek_api, outputs=[deepseek_verified, deepseek_key])
enhance_button.click(fn=enhance_description, inputs=[blip_output, prompt], outputs=[enhanced_description, enhanced_output])
decompose_button.click(fn=decompose_description, inputs=[enhanced_output], outputs=[decomposed_description, decomposed_output])
# Modified event binding for retrieval
retrieve_button.click(
    fn=mix_and_search,
    inputs=[enhanced_output, result_gallery],
    outputs=[retrieve_output, retrieve_gallery]
)

demo.launch(server_name="0.0.0.0", server_port=12345, share=True)
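For readers unfamiliar with the bindings above, here is a minimal, self-contained sketch of the same Gradio event-wiring pattern: each `component.upload` or `button.click` call attaches a plain Python function, with `inputs` passed as arguments and the return values written back to `outputs`. The component and handler names below are illustrative only and are not taken from this repository.

import gradio as gr

def describe(image):
    # placeholder handler; the real app calls BLIP / a VLM here
    return "an image was uploaded" if image is not None else "no image yet"

def run_edit(image, instruction):
    # placeholder handler; the real app runs the BrushEdit pipeline here
    return image, f"applied: {instruction}"

with gr.Blocks() as demo:
    input_image = gr.Image(type="pil")
    instruction = gr.Textbox(label="Instruction")
    status = gr.Textbox(label="Status", interactive=False)
    result = gr.Image(label="Result")
    run_button = gr.Button("Run")

    # upload event: fires when a new image is chosen
    input_image.upload(fn=describe, inputs=[input_image], outputs=[status])
    # click event: fires when the button is pressed
    run_button.click(fn=run_edit, inputs=[input_image, instruction], outputs=[result, status])

if __name__ == "__main__":
    demo.launch()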
brushedit_app_new_jietu.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
brushedit_app_new_jietu2.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
brushedit_app_new_notqwen.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
brushedit_app_old.py
ADDED
@@ -0,0 +1,1702 @@
1 |
+
##!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import os, random, sys
|
4 |
+
import numpy as np
|
5 |
+
import requests
|
6 |
+
import torch
|
7 |
+
|
8 |
+
|
9 |
+
import gradio as gr
|
10 |
+
|
11 |
+
from PIL import Image
|
12 |
+
|
13 |
+
|
14 |
+
from huggingface_hub import hf_hub_download, snapshot_download
|
15 |
+
from scipy.ndimage import binary_dilation, binary_erosion
|
16 |
+
from transformers import (LlavaNextProcessor, LlavaNextForConditionalGeneration,
|
17 |
+
Qwen2VLForConditionalGeneration, Qwen2VLProcessor)
|
18 |
+
|
19 |
+
from segment_anything import SamPredictor, build_sam, SamAutomaticMaskGenerator
|
20 |
+
from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler
|
21 |
+
from diffusers.image_processor import VaeImageProcessor
|
22 |
+
|
23 |
+
|
24 |
+
from app.src.vlm_pipeline import (
|
25 |
+
vlm_response_editing_type,
|
26 |
+
vlm_response_object_wait_for_edit,
|
27 |
+
vlm_response_mask,
|
28 |
+
vlm_response_prompt_after_apply_instruction
|
29 |
+
)
|
30 |
+
from app.src.brushedit_all_in_one_pipeline import BrushEdit_Pipeline
|
31 |
+
from app.utils.utils import load_grounding_dino_model
|
32 |
+
|
33 |
+
from app.src.vlm_template import vlms_template
|
34 |
+
from app.src.base_model_template import base_models_template
|
35 |
+
from app.src.aspect_ratio_template import aspect_ratios
|
36 |
+
|
37 |
+
from openai import OpenAI
|
38 |
+
# base_openai_url = ""
|
39 |
+
|
40 |
+
#### Description ####
|
41 |
+
logo = r"""
|
42 |
+
<center><img src='./assets/logo_brushedit.png' alt='BrushEdit logo' style="width:80px; margin-bottom:10px"></center>
|
43 |
+
"""
|
44 |
+
head = r"""
|
45 |
+
<div style="text-align: center;">
|
46 |
+
<h1> BrushEdit: All-In-One Image Inpainting and Editing</h1>
|
47 |
+
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
|
48 |
+
<a href='https://liyaowei-stu.github.io/project/BrushEdit/'><img src='https://img.shields.io/badge/Project_Page-BrushEdit-green' alt='Project Page'></a>
|
49 |
+
<a href='https://arxiv.org/abs/2412.10316'><img src='https://img.shields.io/badge/Paper-Arxiv-blue'></a>
|
50 |
+
<a href='https://github.com/TencentARC/BrushEdit'><img src='https://img.shields.io/badge/Code-Github-orange'></a>
|
51 |
+
|
52 |
+
</div>
|
53 |
+
</br>
|
54 |
+
</div>
|
55 |
+
"""
|
56 |
+
descriptions = r"""
|
57 |
+
Official Gradio Demo for <a href='https://tencentarc.github.io/BrushNet/'><b>BrushEdit: All-In-One Image Inpainting and Editing</b></a><br>
|
58 |
+
🧙 BrushEdit enables precise, user-friendly instruction-based image editing via an inpainting model.<br>
|
59 |
+
"""
|
60 |
+
|
61 |
+
instructions = r"""
|
62 |
+
Currently, we support two modes: <b>fully automated instruction-based editing</b> and <b>interactive instruction-based editing</b>.
|
63 |
+
|
64 |
+
🛠️ <b>Fully automated instruction-based editing</b>:
|
65 |
+
<ul>
|
66 |
+
<li> ⭐️ <b>1.Choose Image: </b> Upload <img src="https://github.com/user-attachments/assets/f2dca1e6-31f9-4716-ae84-907f24415bac" alt="upload" style="display:inline; height:1em; vertical-align:middle;"> or select <img src="https://github.com/user-attachments/assets/de808f7d-c74a-44c7-9cbf-f0dbfc2c1abf" alt="example" style="display:inline; height:1em; vertical-align:middle;"> one image from Example. </li>
|
67 |
+
<li> ⭐️ <b>2.Input ⌨️ Instructions: </b> Input the instructions (supports addition, deletion, and modification), e.g. remove xxx .</li>
|
68 |
+
<li> ⭐️ <b>3.Run: </b> Click the <b>💫 Run</b> button to automatically edit the image.</li>
|
69 |
+
</ul>
|
70 |
+
|
71 |
+
🛠️ <b>Interactive instruction-based editing</b>:
|
72 |
+
<ul>
|
73 |
+
<li> ⭐️ <b>1.Choose Image: </b> Upload <img src="https://github.com/user-attachments/assets/f2dca1e6-31f9-4716-ae84-907f24415bac" alt="upload" style="display:inline; height:1em; vertical-align:middle;"> or select <img src="https://github.com/user-attachments/assets/de808f7d-c74a-44c7-9cbf-f0dbfc2c1abf" alt="example" style="display:inline; height:1em; vertical-align:middle;"> one image from Example. </li>
|
74 |
+
<li> ⭐️ <b>2.Finely Brushing: </b> Use a brush <img src="https://github.com/user-attachments/assets/c466c5cc-ac8f-4b4a-9bc5-04c4737fe1ef" alt="brush" style="display:inline; height:1em; vertical-align:middle;"> to outline the area you want to edit. And You can also use the eraser <img src="https://github.com/user-attachments/assets/b6370369-b080-4550-b0d0-830ff22d9068" alt="eraser" style="display:inline; height:1em; vertical-align:middle;"> to restore. </li>
|
75 |
+
<li> ⭐️ <b>3.Input ⌨️ Instructions: </b> Input the instructions. </li>
|
76 |
+
<li> ⭐️ <b>4.Run: </b> Click the <b>💫 Run</b> button to automatically edit the image. </li>
|
77 |
+
</ul>
|
78 |
+
|
79 |
+
<b> We strongly recommend using GPT-4o for reasoning. </b> After selecting GPT4-o as the VLM model, enter the API KEY and click the Submit and Verify button. If the output shows success, you can use GPT-4o normally. As a second choice, we recommend the Qwen2VL model.
|
80 |
+
|
81 |
+
<b> We recommend zooming out in your browser for a better viewing range and experience. </b>
|
82 |
+
|
83 |
+
<b> For more detailed feature descriptions, see the bottom. </b>
|
84 |
+
|
85 |
+
☕️ Have fun! 🎄 Wishing you a merry Christmas!
|
86 |
+
"""
|
87 |
+
|
88 |
+
tips = r"""
|
89 |
+
💡 <b>Some Tips</b>:
|
90 |
+
<ul>
|
91 |
+
<li> 🤠 After inputting the instructions, you can click the <b>Generate Mask</b> button. The mask generated by the VLM will be displayed in the preview panel on the right side. </li>
|
92 |
+
<li> 🤠 After generating the mask or when you use the brush to draw the mask, you can perform operations such as <b>randomization</b>, <b>dilation</b>, <b>erosion</b>, and <b>movement</b>. </li>
|
93 |
+
<li> 🤠 After inputting the instructions, you can click the <b>Generate Target Prompt</b> button. The target prompt will be displayed in the text box, and you can modify it according to your ideas. </li>
|
94 |
+
</ul>
|
95 |
+
|
96 |
+
💡 <b>Detailed Features</b>:
|
97 |
+
<ul>
|
98 |
+
<li> 🎨 <b>Aspect Ratio</b>: Select the aspect ratio of the image. To prevent OOM, 1024px is the maximum resolution.</li>
|
99 |
+
<li> 🎨 <b>VLM Model</b>: Select the VLM model. We use preloaded models to save time. To use other VLM models, download them and uncomment the relevant lines in vlm_template.py from our GitHub repo. </li>
|
100 |
+
<li> 🎨 <b>Generate Mask</b>: According to the input instructions, generate a mask for the area that may need to be edited. </li>
|
101 |
+
<li> 🎨 <b>Square/Circle Mask</b>: Based on the existing mask, generate masks for squares and circles. (The coarse-grained mask provides more editing imagination.) </li>
|
102 |
+
<li> 🎨 <b>Invert Mask</b>: Invert the mask to generate a new mask. </li>
|
103 |
+
<li> 🎨 <b>Dilation/Erosion Mask</b>: Expand or shrink the mask to include or exclude more areas. </li>
|
104 |
+
<li> 🎨 <b>Move Mask</b>: Move the mask to a new position. </li>
|
105 |
+
<li> 🎨 <b>Generate Target Prompt</b>: Generate a target prompt based on the input instructions. </li>
|
106 |
+
<li> 🎨 <b>Target Prompt</b>: Description for masking area, manual input or modification can be made when the content generated by VLM does not meet expectations. </li>
|
107 |
+
<li> 🎨 <b>Blending</b>: Blending brushnet's output and the original input, preserving the original image details in the unedited areas. (Turning it off is better when removing.) </li>
|
108 |
+
<li> 🎨 <b>Control strength</b>: The intensity of editing and inpainting. </li>
|
109 |
+
</ul>
|
110 |
+
|
111 |
+
💡 <b>Advanced Features</b>:
|
112 |
+
<ul>
|
113 |
+
<li> 🎨 <b>Base Model</b>: We use preloaded models to save time. To use other base models, download them and uncomment the relevant lines in base_model_template.py from our GitHub repo. </li>
|
114 |
+
<li> 🎨 <b>Blending</b>: Blending brushnet's output and the original input, preserving the original image details in the unedited areas. (Turning it off is better when removing.) </li>
|
115 |
+
<li> 🎨 <b>Control strength</b>: The intensity of editing and inpainting. </li>
|
116 |
+
<li> 🎨 <b>Num samples</b>: The number of samples to generate. </li>
|
117 |
+
<li> 🎨 <b>Negative prompt</b>: The negative prompt for the classifier-free guidance. </li>
|
118 |
+
<li> 🎨 <b>Guidance scale</b>: The guidance scale for the classifier-free guidance. </li>
|
119 |
+
</ul>
|
120 |
+
|
121 |
+
|
122 |
+
"""
|
123 |
+
|
124 |
+
|
125 |
+
|
126 |
+
citation = r"""
|
127 |
+
If BrushEdit is helpful, please help to ⭐ the <a href='https://github.com/TencentARC/BrushEdit' target='_blank'>Github Repo</a>. Thanks!
|
128 |
+
[](https://github.com/TencentARC/BrushEdit)
|
129 |
+
---
|
130 |
+
📝 **Citation**
|
131 |
+
<br>
|
132 |
+
If our work is useful for your research, please consider citing:
|
133 |
+
```bibtex
|
134 |
+
@misc{li2024brushedit,
|
135 |
+
title={BrushEdit: All-In-One Image Inpainting and Editing},
|
136 |
+
author={Yaowei Li and Yuxuan Bian and Xuan Ju and Zhaoyang Zhang and and Junhao Zhuang and Ying Shan and Yuexian Zou and Qiang Xu},
|
137 |
+
year={2024},
|
138 |
+
eprint={2412.10316},
|
139 |
+
archivePrefix={arXiv},
|
140 |
+
primaryClass={cs.CV}
|
141 |
+
}
|
142 |
+
```
|
143 |
+
📧 **Contact**
|
144 |
+
<br>
|
145 |
+
If you have any questions, please feel free to reach me out at <b>liyaowei@gmail.com</b>.
|
146 |
+
"""
|
147 |
+
|
148 |
+
# - - - - - examples - - - - - #
|
149 |
+
EXAMPLES = [
|
150 |
+
|
151 |
+
[
|
152 |
+
Image.open("./assets/frog/frog.jpeg").convert("RGBA"),
|
153 |
+
"add a magic hat on frog head.",
|
154 |
+
642087011,
|
155 |
+
"frog",
|
156 |
+
"frog",
|
157 |
+
True,
|
158 |
+
False,
|
159 |
+
"GPT4-o (Highly Recommended)"
|
160 |
+
],
|
161 |
+
[
|
162 |
+
Image.open("./assets/chinese_girl/chinese_girl.png").convert("RGBA"),
|
163 |
+
"replace the background to ancient China.",
|
164 |
+
648464818,
|
165 |
+
"chinese_girl",
|
166 |
+
"chinese_girl",
|
167 |
+
True,
|
168 |
+
False,
|
169 |
+
"GPT4-o (Highly Recommended)"
|
170 |
+
],
|
171 |
+
[
|
172 |
+
Image.open("./assets/angel_christmas/angel_christmas.png").convert("RGBA"),
|
173 |
+
"remove the deer.",
|
174 |
+
648464818,
|
175 |
+
"angel_christmas",
|
176 |
+
"angel_christmas",
|
177 |
+
False,
|
178 |
+
False,
|
179 |
+
"GPT4-o (Highly Recommended)"
|
180 |
+
],
|
181 |
+
[
|
182 |
+
Image.open("./assets/sunflower_girl/sunflower_girl.png").convert("RGBA"),
|
183 |
+
"add a wreath on head.",
|
184 |
+
648464818,
|
185 |
+
"sunflower_girl",
|
186 |
+
"sunflower_girl",
|
187 |
+
True,
|
188 |
+
False,
|
189 |
+
"GPT4-o (Highly Recommended)"
|
190 |
+
],
|
191 |
+
[
|
192 |
+
Image.open("./assets/girl_on_sun/girl_on_sun.png").convert("RGBA"),
|
193 |
+
"add a butterfly fairy.",
|
194 |
+
648464818,
|
195 |
+
"girl_on_sun",
|
196 |
+
"girl_on_sun",
|
197 |
+
True,
|
198 |
+
False,
|
199 |
+
"GPT4-o (Highly Recommended)"
|
200 |
+
],
|
201 |
+
[
|
202 |
+
Image.open("./assets/spider_man_rm/spider_man.png").convert("RGBA"),
|
203 |
+
"remove the christmas hat.",
|
204 |
+
642087011,
|
205 |
+
"spider_man_rm",
|
206 |
+
"spider_man_rm",
|
207 |
+
False,
|
208 |
+
False,
|
209 |
+
"GPT4-o (Highly Recommended)"
|
210 |
+
],
|
211 |
+
[
|
212 |
+
Image.open("./assets/anime_flower/anime_flower.png").convert("RGBA"),
|
213 |
+
"remove the flower.",
|
214 |
+
642087011,
|
215 |
+
"anime_flower",
|
216 |
+
"anime_flower",
|
217 |
+
False,
|
218 |
+
False,
|
219 |
+
"GPT4-o (Highly Recommended)"
|
220 |
+
],
|
221 |
+
[
|
222 |
+
Image.open("./assets/chenduling/chengduling.jpg").convert("RGBA"),
|
223 |
+
"replace the clothes to a delicated floral skirt.",
|
224 |
+
648464818,
|
225 |
+
"chenduling",
|
226 |
+
"chenduling",
|
227 |
+
True,
|
228 |
+
False,
|
229 |
+
"GPT4-o (Highly Recommended)"
|
230 |
+
],
|
231 |
+
[
|
232 |
+
Image.open("./assets/hedgehog_rp_bg/hedgehog.png").convert("RGBA"),
|
233 |
+
"make the hedgehog in Italy.",
|
234 |
+
648464818,
|
235 |
+
"hedgehog_rp_bg",
|
236 |
+
"hedgehog_rp_bg",
|
237 |
+
True,
|
238 |
+
False,
|
239 |
+
"GPT4-o (Highly Recommended)"
|
240 |
+
],
|
241 |
+
|
242 |
+
]
|
243 |
+
|
244 |
+
INPUT_IMAGE_PATH = {
|
245 |
+
"frog": "./assets/frog/frog.jpeg",
|
246 |
+
"chinese_girl": "./assets/chinese_girl/chinese_girl.png",
|
247 |
+
"angel_christmas": "./assets/angel_christmas/angel_christmas.png",
|
248 |
+
"sunflower_girl": "./assets/sunflower_girl/sunflower_girl.png",
|
249 |
+
"girl_on_sun": "./assets/girl_on_sun/girl_on_sun.png",
|
250 |
+
"spider_man_rm": "./assets/spider_man_rm/spider_man.png",
|
251 |
+
"anime_flower": "./assets/anime_flower/anime_flower.png",
|
252 |
+
"chenduling": "./assets/chenduling/chengduling.jpg",
|
253 |
+
"hedgehog_rp_bg": "./assets/hedgehog_rp_bg/hedgehog.png",
|
254 |
+
}
|
255 |
+
MASK_IMAGE_PATH = {
|
256 |
+
"frog": "./assets/frog/mask_f7b350de-6f2c-49e3-b535-995c486d78e7.png",
|
257 |
+
"chinese_girl": "./assets/chinese_girl/mask_54759648-0989-48e0-bc82-f20e28b5ec29.png",
|
258 |
+
"angel_christmas": "./assets/angel_christmas/mask_f15d9b45-c978-4e3d-9f5f-251e308560c3.png",
|
259 |
+
"sunflower_girl": "./assets/sunflower_girl/mask_99cc50b4-7dc4-4de5-8748-ec10772f0317.png",
|
260 |
+
"girl_on_sun": "./assets/girl_on_sun/mask_264eac8b-8b65-479c-9755-020a60880c37.png",
|
261 |
+
"spider_man_rm": "./assets/spider_man_rm/mask_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png",
|
262 |
+
"anime_flower": "./assets/anime_flower/mask_37553172-9b38-4727-bf2e-37d7e2b93461.png",
|
263 |
+
"chenduling": "./assets/chenduling/mask_68e3ff6f-da07-4b37-91df-13d6eed7b997.png",
|
264 |
+
"hedgehog_rp_bg": "./assets/hedgehog_rp_bg/mask_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png",
|
265 |
+
}
|
266 |
+
MASKED_IMAGE_PATH = {
|
267 |
+
"frog": "./assets/frog/masked_image_f7b350de-6f2c-49e3-b535-995c486d78e7.png",
|
268 |
+
"chinese_girl": "./assets/chinese_girl/masked_image_54759648-0989-48e0-bc82-f20e28b5ec29.png",
|
269 |
+
"angel_christmas": "./assets/angel_christmas/masked_image_f15d9b45-c978-4e3d-9f5f-251e308560c3.png",
|
270 |
+
"sunflower_girl": "./assets/sunflower_girl/masked_image_99cc50b4-7dc4-4de5-8748-ec10772f0317.png",
|
271 |
+
"girl_on_sun": "./assets/girl_on_sun/masked_image_264eac8b-8b65-479c-9755-020a60880c37.png",
|
272 |
+
"spider_man_rm": "./assets/spider_man_rm/masked_image_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png",
|
273 |
+
"anime_flower": "./assets/anime_flower/masked_image_37553172-9b38-4727-bf2e-37d7e2b93461.png",
|
274 |
+
"chenduling": "./assets/chenduling/masked_image_68e3ff6f-da07-4b37-91df-13d6eed7b997.png",
|
275 |
+
"hedgehog_rp_bg": "./assets/hedgehog_rp_bg/masked_image_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png",
|
276 |
+
}
|
277 |
+
OUTPUT_IMAGE_PATH = {
|
278 |
+
"frog": "./assets/frog/image_edit_f7b350de-6f2c-49e3-b535-995c486d78e7_1.png",
|
279 |
+
"chinese_girl": "./assets/chinese_girl/image_edit_54759648-0989-48e0-bc82-f20e28b5ec29_1.png",
|
280 |
+
"angel_christmas": "./assets/angel_christmas/image_edit_f15d9b45-c978-4e3d-9f5f-251e308560c3_0.png",
|
281 |
+
"sunflower_girl": "./assets/sunflower_girl/image_edit_99cc50b4-7dc4-4de5-8748-ec10772f0317_3.png",
|
282 |
+
"girl_on_sun": "./assets/girl_on_sun/image_edit_264eac8b-8b65-479c-9755-020a60880c37_0.png",
|
283 |
+
"spider_man_rm": "./assets/spider_man_rm/image_edit_a5d410e6-8e8d-432f-8144-defbc3e1eae9_0.png",
|
284 |
+
"anime_flower": "./assets/anime_flower/image_edit_37553172-9b38-4727-bf2e-37d7e2b93461_2.png",
|
285 |
+
"chenduling": "./assets/chenduling/image_edit_68e3ff6f-da07-4b37-91df-13d6eed7b997_0.png",
|
286 |
+
"hedgehog_rp_bg": "./assets/hedgehog_rp_bg/image_edit_db7f8bf8-8349-46d3-b14e-43d67fbe25d3_3.png",
|
287 |
+
}
|
288 |
+
|
289 |
+
# os.environ['GRADIO_TEMP_DIR'] = 'gradio_temp_dir'
|
290 |
+
# os.makedirs('gradio_temp_dir', exist_ok=True)
|
291 |
+
|
292 |
+
VLM_MODEL_NAMES = list(vlms_template.keys())
|
293 |
+
DEFAULT_VLM_MODEL_NAME = "Qwen2-VL-7B-Instruct (Default)"
|
294 |
+
BASE_MODELS = list(base_models_template.keys())
|
295 |
+
DEFAULT_BASE_MODEL = "realisticVision (Default)"
|
296 |
+
|
297 |
+
ASPECT_RATIO_LABELS = list(aspect_ratios)
|
298 |
+
DEFAULT_ASPECT_RATIO = ASPECT_RATIO_LABELS[0]
|
299 |
+
|
300 |
+
|
301 |
+
## init device
|
302 |
+
try:
|
303 |
+
if torch.cuda.is_available():
|
304 |
+
device = "cuda"
|
305 |
+
elif sys.platform == "darwin" and torch.backends.mps.is_available():
|
306 |
+
device = "mps"
|
307 |
+
else:
|
308 |
+
device = "cpu"
|
309 |
+
except:
|
310 |
+
device = "cpu"
|
311 |
+
|
312 |
+
# ## init torch dtype
|
313 |
+
# if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
|
314 |
+
# torch_dtype = torch.bfloat16
|
315 |
+
# else:
|
316 |
+
# torch_dtype = torch.float16
|
317 |
+
|
318 |
+
# if device == "mps":
|
319 |
+
# torch_dtype = torch.float16
|
320 |
+
|
321 |
+
torch_dtype = torch.float16
|
322 |
+
|
323 |
+
|
324 |
+
|
325 |
+
# download hf models
|
326 |
+
BrushEdit_path = "models/"
|
327 |
+
if not os.path.exists(BrushEdit_path):
|
328 |
+
BrushEdit_path = snapshot_download(
|
329 |
+
repo_id="TencentARC/BrushEdit",
|
330 |
+
local_dir=BrushEdit_path,
|
331 |
+
token=os.getenv("HF_TOKEN"),
|
332 |
+
)
|
333 |
+
|
334 |
+
## init default VLM
|
335 |
+
vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[DEFAULT_VLM_MODEL_NAME]
|
336 |
+
if vlm_processor != "" and vlm_model != "":
|
337 |
+
vlm_model.to(device)
|
338 |
+
else:
|
339 |
+
raise gr.Error("Please Download default VLM model "+ DEFAULT_VLM_MODEL_NAME +" first.")
|
340 |
+
|
341 |
+
|
342 |
+
## init base model
|
343 |
+
base_model_path = os.path.join(BrushEdit_path, "base_model/realisticVisionV60B1_v51VAE")
|
344 |
+
brushnet_path = os.path.join(BrushEdit_path, "brushnetX")
|
345 |
+
sam_path = os.path.join(BrushEdit_path, "sam/sam_vit_h_4b8939.pth")
|
346 |
+
groundingdino_path = os.path.join(BrushEdit_path, "grounding_dino/groundingdino_swint_ogc.pth")
|
347 |
+
|
348 |
+
|
349 |
+
# input brushnetX ckpt path
|
350 |
+
brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch_dtype)
|
351 |
+
pipe = StableDiffusionBrushNetPipeline.from_pretrained(
|
352 |
+
base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
|
353 |
+
)
|
354 |
+
# speed up diffusion process with faster scheduler and memory optimization
|
355 |
+
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
|
356 |
+
# remove following line if xformers is not installed or when using Torch 2.0.
|
357 |
+
# pipe.enable_xformers_memory_efficient_attention()
|
358 |
+
pipe.enable_model_cpu_offload()
|
359 |
+
|
360 |
+
|
361 |
+
## init SAM
|
362 |
+
sam = build_sam(checkpoint=sam_path)
|
363 |
+
sam.to(device=device)
|
364 |
+
sam_predictor = SamPredictor(sam)
|
365 |
+
sam_automask_generator = SamAutomaticMaskGenerator(sam)
|
366 |
+
|
367 |
+
## init groundingdino_model
|
368 |
+
config_file = 'app/utils/GroundingDINO_SwinT_OGC.py'
|
369 |
+
groundingdino_model = load_grounding_dino_model(config_file, groundingdino_path, device=device)
|
370 |
+
|
371 |
+
## Ordinary function
|
372 |
+
def crop_and_resize(image: Image.Image,
|
373 |
+
target_width: int,
|
374 |
+
target_height: int) -> Image.Image:
|
375 |
+
"""
|
376 |
+
Crops and resizes an image while preserving the aspect ratio.
|
377 |
+
|
378 |
+
Args:
|
379 |
+
image (Image.Image): Input PIL image to be cropped and resized.
|
380 |
+
target_width (int): Target width of the output image.
|
381 |
+
target_height (int): Target height of the output image.
|
382 |
+
|
383 |
+
Returns:
|
384 |
+
Image.Image: Cropped and resized image.
|
385 |
+
"""
|
386 |
+
# Original dimensions
|
387 |
+
original_width, original_height = image.size
|
388 |
+
original_aspect = original_width / original_height
|
389 |
+
target_aspect = target_width / target_height
|
390 |
+
|
391 |
+
# Calculate crop box to maintain aspect ratio
|
392 |
+
if original_aspect > target_aspect:
|
393 |
+
# Crop horizontally
|
394 |
+
new_width = int(original_height * target_aspect)
|
395 |
+
new_height = original_height
|
396 |
+
left = (original_width - new_width) / 2
|
397 |
+
top = 0
|
398 |
+
right = left + new_width
|
399 |
+
bottom = original_height
|
400 |
+
else:
|
401 |
+
# Crop vertically
|
402 |
+
new_width = original_width
|
403 |
+
new_height = int(original_width / target_aspect)
|
404 |
+
left = 0
|
405 |
+
top = (original_height - new_height) / 2
|
406 |
+
right = original_width
|
407 |
+
bottom = top + new_height
|
408 |
+
|
409 |
+
# Crop and resize
|
410 |
+
cropped_image = image.crop((left, top, right, bottom))
|
411 |
+
resized_image = cropped_image.resize((target_width, target_height), Image.NEAREST)
|
412 |
+
return resized_image
|
413 |
+
|
414 |
+
|
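# Illustrative usage of crop_and_resize (a hypothetical example, not part of the original file):
# a 1000x800 image with a 640x640 target is first center-cropped to 800x800
# (original aspect 1.25 > target aspect 1.0, so it crops horizontally), then resized, e.g.:
#   img = Image.open("example.jpg")                      # hypothetical path
#   square = crop_and_resize(img, target_width=640, target_height=640)
#   assert square.size == (640, 640)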
415 |
+
## Ordinary function
|
416 |
+
def resize(image: Image.Image,
|
417 |
+
target_width: int,
|
418 |
+
target_height: int) -> Image.Image:
|
419 |
+
"""
|
420 |
+
Crops and resizes an image while preserving the aspect ratio.
|
421 |
+
|
422 |
+
Args:
|
423 |
+
image (Image.Image): Input PIL image to be cropped and resized.
|
424 |
+
target_width (int): Target width of the output image.
|
425 |
+
target_height (int): Target height of the output image.
|
426 |
+
|
427 |
+
Returns:
|
428 |
+
Image.Image: Cropped and resized image.
|
429 |
+
"""
|
430 |
+
# Original dimensions
|
431 |
+
resized_image = image.resize((target_width, target_height), Image.NEAREST)
|
432 |
+
return resized_image
|
433 |
+
|
434 |
+
|
435 |
+
def move_mask_func(mask, direction, units):
|
436 |
+
binary_mask = mask.squeeze()>0
|
437 |
+
rows, cols = binary_mask.shape
|
438 |
+
moved_mask = np.zeros_like(binary_mask, dtype=bool)
|
439 |
+
|
440 |
+
if direction == 'down':
|
441 |
+
# move down
|
442 |
+
moved_mask[max(0, units):, :] = binary_mask[:rows - units, :]
|
443 |
+
|
444 |
+
elif direction == 'up':
|
445 |
+
# move up
|
446 |
+
moved_mask[:rows - units, :] = binary_mask[units:, :]
|
447 |
+
|
448 |
+
elif direction == 'right':
|
449 |
+
# move right
|
450 |
+
moved_mask[:, max(0, units):] = binary_mask[:, :cols - units]
|
451 |
+
|
452 |
+
elif direction == 'left':
|
453 |
+
# move left
|
454 |
+
moved_mask[:, :cols - units] = binary_mask[:, units:]
|
455 |
+
|
456 |
+
return moved_mask
|
457 |
+
|
458 |
+
|
459 |
+
def random_mask_func(mask, dilation_type='square', dilation_size=20):
|
460 |
+
# Randomly select the size of dilation
|
461 |
+
binary_mask = mask.squeeze()>0
|
462 |
+
|
463 |
+
if dilation_type == 'square_dilation':
|
464 |
+
structure = np.ones((dilation_size, dilation_size), dtype=bool)
|
465 |
+
dilated_mask = binary_dilation(binary_mask, structure=structure)
|
466 |
+
elif dilation_type == 'square_erosion':
|
467 |
+
structure = np.ones((dilation_size, dilation_size), dtype=bool)
|
468 |
+
dilated_mask = binary_erosion(binary_mask, structure=structure)
|
469 |
+
elif dilation_type == 'bounding_box':
|
470 |
+
# find the most left top and left bottom point
|
471 |
+
rows, cols = np.where(binary_mask)
|
472 |
+
if len(rows) == 0 or len(cols) == 0:
|
473 |
+
return mask # return original mask if no valid points
|
474 |
+
|
475 |
+
min_row = np.min(rows)
|
476 |
+
max_row = np.max(rows)
|
477 |
+
min_col = np.min(cols)
|
478 |
+
max_col = np.max(cols)
|
479 |
+
|
480 |
+
# create a bounding box
|
481 |
+
dilated_mask = np.zeros_like(binary_mask, dtype=bool)
|
482 |
+
dilated_mask[min_row:max_row + 1, min_col:max_col + 1] = True
|
483 |
+
|
484 |
+
elif dilation_type == 'bounding_ellipse':
|
485 |
+
# find the most left top and left bottom point
|
486 |
+
rows, cols = np.where(binary_mask)
|
487 |
+
if len(rows) == 0 or len(cols) == 0:
|
488 |
+
return mask # return original mask if no valid points
|
489 |
+
|
490 |
+
min_row = np.min(rows)
|
491 |
+
max_row = np.max(rows)
|
492 |
+
min_col = np.min(cols)
|
493 |
+
max_col = np.max(cols)
|
494 |
+
|
495 |
+
# calculate the center and axis length of the ellipse
|
496 |
+
center = ((min_col + max_col) // 2, (min_row + max_row) // 2)
|
497 |
+
a = (max_col - min_col) // 2 # half long axis
|
498 |
+
b = (max_row - min_row) // 2 # half short axis
|
499 |
+
|
500 |
+
# create a bounding ellipse
|
501 |
+
y, x = np.ogrid[:mask.shape[0], :mask.shape[1]]
|
502 |
+
ellipse_mask = ((x - center[0])**2 / a**2 + (y - center[1])**2 / b**2) <= 1
|
503 |
+
dilated_mask = np.zeros_like(binary_mask, dtype=bool)
|
504 |
+
dilated_mask[ellipse_mask] = True
|
505 |
+
else:
|
506 |
+
raise ValueError("dilation_type must be 'square_dilation', 'square_erosion', 'bounding_box' or 'bounding_ellipse'")
|
507 |
+
|
508 |
+
# use binary dilation
|
509 |
+
dilated_mask = np.uint8(dilated_mask[:,:,np.newaxis]) * 255
|
510 |
+
return dilated_mask
|
511 |
+
|
512 |
+
|
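# Illustrative note (hypothetical example, not part of the original file): with a square
# structuring element of size k, binary_dilation grows the mask by about k//2 pixels on
# each side, and binary_erosion shrinks it by the same amount, e.g.:
#   m = np.zeros((7, 7), dtype=np.uint8); m[3, 3] = 255
#   grown = binary_dilation(m > 0, structure=np.ones((3, 3), dtype=bool))
#   assert grown.sum() == 9   # the single pixel becomes a 3x3 block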
513 |
+
## Gradio component function
|
514 |
+
def update_vlm_model(vlm_name):
|
515 |
+
global vlm_model, vlm_processor
|
516 |
+
if vlm_model is not None:
|
517 |
+
del vlm_model
|
518 |
+
torch.cuda.empty_cache()
|
519 |
+
|
520 |
+
vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[vlm_name]
|
521 |
+
|
522 |
+
## we recommend using preload models, otherwise it will take a long time to download the model. you can edit the code via vlm_template.py
|
523 |
+
if vlm_type == "llava-next":
|
524 |
+
if vlm_processor != "" and vlm_model != "":
|
525 |
+
vlm_model.to(device)
|
526 |
+
return vlm_model_dropdown
|
527 |
+
else:
|
528 |
+
if os.path.exists(vlm_local_path):
|
529 |
+
vlm_processor = LlavaNextProcessor.from_pretrained(vlm_local_path)
|
530 |
+
vlm_model = LlavaNextForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto")
|
531 |
+
else:
|
532 |
+
if vlm_name == "llava-v1.6-mistral-7b-hf (Preload)":
|
533 |
+
vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
|
534 |
+
vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", torch_dtype="auto", device_map="auto")
|
535 |
+
elif vlm_name == "llama3-llava-next-8b-hf (Preload)":
|
536 |
+
vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llama3-llava-next-8b-hf")
|
537 |
+
vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llama3-llava-next-8b-hf", torch_dtype="auto", device_map="auto")
|
538 |
+
elif vlm_name == "llava-v1.6-vicuna-13b-hf (Preload)":
|
539 |
+
vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf")
|
540 |
+
vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf", torch_dtype="auto", device_map="auto")
|
541 |
+
elif vlm_name == "llava-v1.6-34b-hf (Preload)":
|
542 |
+
vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-34b-hf")
|
543 |
+
vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-34b-hf", torch_dtype="auto", device_map="auto")
|
544 |
+
elif vlm_name == "llava-next-72b-hf (Preload)":
|
545 |
+
vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-next-72b-hf")
|
546 |
+
vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-next-72b-hf", torch_dtype="auto", device_map="auto")
|
547 |
+
elif vlm_type == "qwen2-vl":
|
548 |
+
if vlm_processor != "" and vlm_model != "":
|
549 |
+
vlm_model.to(device)
|
550 |
+
return vlm_model_dropdown
|
551 |
+
else:
|
552 |
+
if os.path.exists(vlm_local_path):
|
553 |
+
vlm_processor = Qwen2VLProcessor.from_pretrained(vlm_local_path)
|
554 |
+
vlm_model = Qwen2VLForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto")
|
555 |
+
else:
|
556 |
+
if vlm_name == "qwen2-vl-2b-instruct (Preload)":
|
557 |
+
vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
|
558 |
+
vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", torch_dtype="auto", device_map="auto")
|
559 |
+
elif vlm_name == "qwen2-vl-7b-instruct (Preload)":
|
560 |
+
vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
|
561 |
+
vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto")
|
562 |
+
elif vlm_name == "qwen2-vl-72b-instruct (Preload)":
|
563 |
+
vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-72B-Instruct")
|
564 |
+
vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-72B-Instruct", torch_dtype="auto", device_map="auto")
|
565 |
+
elif vlm_type == "openai":
|
566 |
+
pass
|
567 |
+
return "success"
|
568 |
+
|
569 |
+
|
570 |
+
def update_base_model(base_model_name):
|
571 |
+
global pipe
|
572 |
+
## we recommend using preload models, otherwise it will take a long time to download the model. you can edit the code via base_model_template.py
|
573 |
+
if pipe is not None:
|
574 |
+
del pipe
|
575 |
+
torch.cuda.empty_cache()
|
576 |
+
base_model_path, pipe = base_models_template[base_model_name]
|
577 |
+
if pipe != "":
|
578 |
+
pipe.to(device)
|
579 |
+
else:
|
580 |
+
if os.path.exists(base_model_path):
|
581 |
+
pipe = StableDiffusionBrushNetPipeline.from_pretrained(
|
582 |
+
base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
|
583 |
+
)
|
584 |
+
# pipe.enable_xformers_memory_efficient_attention()
|
585 |
+
pipe.enable_model_cpu_offload()
|
586 |
+
else:
|
587 |
+
raise gr.Error(f"The base model {base_model_name} does not exist")
|
588 |
+
return "success"
|
589 |
+
|
590 |
+
|
591 |
+
def submit_GPT4o_KEY(GPT4o_KEY):
|
592 |
+
global vlm_model, vlm_processor
|
593 |
+
if vlm_model is not None:
|
594 |
+
del vlm_model
|
595 |
+
torch.cuda.empty_cache()
|
596 |
+
try:
|
597 |
+
vlm_model = OpenAI(api_key=GPT4o_KEY)
|
598 |
+
vlm_processor = ""
|
599 |
+
response = vlm_model.chat.completions.create(
|
600 |
+
model="gpt-4o-2024-08-06",
|
601 |
+
messages=[
|
602 |
+
{"role": "system", "content": "You are a helpful assistant."},
|
603 |
+
{"role": "user", "content": "Say this is a test"}
|
604 |
+
]
|
605 |
+
)
|
606 |
+
response_str = response.choices[0].message.content
|
607 |
+
|
608 |
+
return "Success, " + response_str, "GPT4-o (Highly Recommended)"
|
609 |
+
except Exception as e:
|
610 |
+
return "Invalid GPT4o API Key", "GPT4-o (Highly Recommended)"
|
611 |
+
|
612 |
+
|
613 |
+
|
614 |
+
def process(input_image,
|
615 |
+
original_image,
|
616 |
+
original_mask,
|
617 |
+
prompt,
|
618 |
+
negative_prompt,
|
619 |
+
control_strength,
|
620 |
+
seed,
|
621 |
+
randomize_seed,
|
622 |
+
guidance_scale,
|
623 |
+
num_inference_steps,
|
624 |
+
num_samples,
|
625 |
+
blending,
|
626 |
+
category,
|
627 |
+
target_prompt,
|
628 |
+
resize_default,
|
629 |
+
aspect_ratio_name,
|
630 |
+
invert_mask_state):
|
631 |
+
if original_image is None:
|
632 |
+
if input_image is None:
|
633 |
+
raise gr.Error('Please upload the input image')
|
634 |
+
else:
|
635 |
+
image_pil = input_image["background"].convert("RGB")
|
636 |
+
original_image = np.array(image_pil)
|
637 |
+
if prompt is None or prompt == "":
|
638 |
+
if target_prompt is None or target_prompt == "":
|
639 |
+
raise gr.Error("Please input your instructions, e.g., remove the xxx")
|
640 |
+
|
641 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
642 |
+
input_mask = np.asarray(alpha_mask)
|
643 |
+
output_w, output_h = aspect_ratios[aspect_ratio_name]
|
644 |
+
if output_w == "" or output_h == "":
|
645 |
+
output_h, output_w = original_image.shape[:2]
|
646 |
+
|
647 |
+
if resize_default:
|
648 |
+
short_side = min(output_w, output_h)
|
649 |
+
scale_ratio = 640 / short_side
|
650 |
+
output_w = int(output_w * scale_ratio)
|
651 |
+
output_h = int(output_h * scale_ratio)
|
652 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
653 |
+
original_image = np.array(original_image)
|
654 |
+
if input_mask is not None:
|
655 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
656 |
+
input_mask = np.array(input_mask)
|
657 |
+
if original_mask is not None:
|
658 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
659 |
+
original_mask = np.array(original_mask)
|
660 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
661 |
+
else:
|
662 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
663 |
+
pass
|
664 |
+
else:
|
665 |
+
if resize_default:
|
666 |
+
short_side = min(output_w, output_h)
|
667 |
+
scale_ratio = 640 / short_side
|
668 |
+
output_w = int(output_w * scale_ratio)
|
669 |
+
output_h = int(output_h * scale_ratio)
|
670 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
671 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
672 |
+
original_image = np.array(original_image)
|
673 |
+
if input_mask is not None:
|
674 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
675 |
+
input_mask = np.array(input_mask)
|
676 |
+
if original_mask is not None:
|
677 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
678 |
+
original_mask = np.array(original_mask)
|
679 |
+
|
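# Worked example (illustrative, not part of the original file): with resize_default enabled
# and a 1280x720 target, short_side = 720 and scale_ratio = 640 / 720 ≈ 0.889, so the
# output becomes roughly 1137x640, i.e., the short side is scaled to about 640 px before editing.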
680 |
+
if invert_mask_state:
|
681 |
+
original_mask = original_mask
|
682 |
+
else:
|
683 |
+
if input_mask.max() == 0:
|
684 |
+
original_mask = original_mask
|
685 |
+
else:
|
686 |
+
original_mask = input_mask
|
687 |
+
|
688 |
+
|
689 |
+
## inpainting directly if target_prompt is not None
|
690 |
+
if category is not None:
|
691 |
+
pass
|
692 |
+
elif target_prompt is not None and len(target_prompt) >= 1 and original_mask is not None:
|
693 |
+
pass
|
694 |
+
else:
|
695 |
+
try:
|
696 |
+
category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device)
|
697 |
+
except Exception as e:
|
698 |
+
raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
|
699 |
+
|
700 |
+
|
701 |
+
if original_mask is not None:
|
702 |
+
original_mask = np.clip(original_mask, 0, 255).astype(np.uint8)
|
703 |
+
else:
|
704 |
+
try:
|
705 |
+
object_wait_for_edit = vlm_response_object_wait_for_edit(
|
706 |
+
vlm_processor,
|
707 |
+
vlm_model,
|
708 |
+
original_image,
|
709 |
+
category,
|
710 |
+
prompt,
|
711 |
+
device)
|
712 |
+
|
713 |
+
original_mask = vlm_response_mask(vlm_processor,
|
714 |
+
vlm_model,
|
715 |
+
category,
|
716 |
+
original_image,
|
717 |
+
prompt,
|
718 |
+
object_wait_for_edit,
|
719 |
+
sam,
|
720 |
+
sam_predictor,
|
721 |
+
sam_automask_generator,
|
722 |
+
groundingdino_model,
|
723 |
+
device).astype(np.uint8)
|
724 |
+
except Exception as e:
|
725 |
+
raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
|
726 |
+
|
727 |
+
if original_mask.ndim == 2:
|
728 |
+
original_mask = original_mask[:,:,None]
|
729 |
+
|
730 |
+
|
731 |
+
if target_prompt is not None and len(target_prompt) >= 1:
|
732 |
+
prompt_after_apply_instruction = target_prompt
|
733 |
+
|
734 |
+
else:
|
735 |
+
try:
|
736 |
+
prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(
|
737 |
+
vlm_processor,
|
738 |
+
vlm_model,
|
739 |
+
original_image,
|
740 |
+
prompt,
|
741 |
+
device)
|
742 |
+
except Exception as e:
|
743 |
+
raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
|
744 |
+
|
745 |
+
generator = torch.Generator(device).manual_seed(random.randint(0, 2147483647) if randomize_seed else seed)
|
746 |
+
|
747 |
+
|
748 |
+
with torch.autocast(device):
|
749 |
+
image, mask_image, mask_np, init_image_np = BrushEdit_Pipeline(pipe,
|
750 |
+
prompt_after_apply_instruction,
|
751 |
+
original_mask,
|
752 |
+
original_image,
|
753 |
+
generator,
|
754 |
+
num_inference_steps,
|
755 |
+
guidance_scale,
|
756 |
+
control_strength,
|
757 |
+
negative_prompt,
|
758 |
+
num_samples,
|
759 |
+
blending)
|
760 |
+
original_image = np.array(init_image_np)
|
761 |
+
masked_image = original_image * (1 - (mask_np>0))
|
762 |
+
masked_image = masked_image.astype(np.uint8)
|
763 |
+
masked_image = Image.fromarray(masked_image)
|
764 |
+
# Save the images (optional)
|
765 |
+
# import uuid
|
766 |
+
# uuid = str(uuid.uuid4())
|
767 |
+
# image[0].save(f"outputs/image_edit_{uuid}_0.png")
|
768 |
+
# image[1].save(f"outputs/image_edit_{uuid}_1.png")
|
769 |
+
# image[2].save(f"outputs/image_edit_{uuid}_2.png")
|
770 |
+
# image[3].save(f"outputs/image_edit_{uuid}_3.png")
|
771 |
+
# mask_image.save(f"outputs/mask_{uuid}.png")
|
772 |
+
# masked_image.save(f"outputs/masked_image_{uuid}.png")
|
773 |
+
gr.Info(f"Target Prompt: {prompt_after_apply_instruction}", duration=20)
|
774 |
+
return image, [mask_image], [masked_image], prompt, '', False
|
775 |
+
|
776 |
+
|
777 |
+
def generate_target_prompt(input_image,
|
778 |
+
original_image,
|
779 |
+
prompt):
|
780 |
+
# load example image
|
781 |
+
if isinstance(original_image, str):
|
782 |
+
original_image = input_image
|
783 |
+
|
784 |
+
prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(
|
785 |
+
vlm_processor,
|
786 |
+
vlm_model,
|
787 |
+
original_image,
|
788 |
+
prompt,
|
789 |
+
device)
|
790 |
+
return prompt_after_apply_instruction
|
791 |
+
|
792 |
+
|
793 |
+
def process_mask(input_image,
|
794 |
+
original_image,
|
795 |
+
prompt,
|
796 |
+
resize_default,
|
797 |
+
aspect_ratio_name):
|
798 |
+
if original_image is None:
|
799 |
+
raise gr.Error('Please upload the input image')
|
800 |
+
if prompt is None:
|
801 |
+
raise gr.Error("Please input your instructions, e.g., remove the xxx")
|
802 |
+
|
803 |
+
## load mask
|
804 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
805 |
+
input_mask = np.array(alpha_mask)
|
806 |
+
|
807 |
+
# load example image
|
808 |
+
if isinstance(original_image, str):
|
809 |
+
original_image = input_image["background"]
|
810 |
+
|
811 |
+
if input_mask.max() == 0:
|
812 |
+
category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device)
|
813 |
+
|
814 |
+
object_wait_for_edit = vlm_response_object_wait_for_edit(vlm_processor,
|
815 |
+
vlm_model,
|
816 |
+
original_image,
|
817 |
+
category,
|
818 |
+
prompt,
|
819 |
+
device)
|
820 |
+
# original mask: h,w,1 [0, 255]
|
821 |
+
original_mask = vlm_response_mask(
|
822 |
+
vlm_processor,
|
823 |
+
vlm_model,
|
824 |
+
category,
|
825 |
+
original_image,
|
826 |
+
prompt,
|
827 |
+
object_wait_for_edit,
|
828 |
+
sam,
|
829 |
+
sam_predictor,
|
830 |
+
sam_automask_generator,
|
831 |
+
groundingdino_model,
|
832 |
+
device).astype(np.uint8)
|
833 |
+
else:
|
834 |
+
original_mask = input_mask.astype(np.uint8)
|
835 |
+
category = None
|
836 |
+
|
837 |
+
## resize mask if needed
|
838 |
+
output_w, output_h = aspect_ratios[aspect_ratio_name]
|
839 |
+
if output_w == "" or output_h == "":
|
840 |
+
output_h, output_w = original_image.shape[:2]
|
841 |
+
if resize_default:
|
842 |
+
short_side = min(output_w, output_h)
|
843 |
+
scale_ratio = 640 / short_side
|
844 |
+
output_w = int(output_w * scale_ratio)
|
845 |
+
output_h = int(output_h * scale_ratio)
|
846 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
847 |
+
original_image = np.array(original_image)
|
848 |
+
if input_mask is not None:
|
849 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
850 |
+
input_mask = np.array(input_mask)
|
851 |
+
if original_mask is not None:
|
852 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
853 |
+
original_mask = np.array(original_mask)
|
854 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
855 |
+
else:
|
856 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
857 |
+
pass
|
858 |
+
else:
|
859 |
+
if resize_default:
|
860 |
+
short_side = min(output_w, output_h)
|
861 |
+
scale_ratio = 640 / short_side
|
862 |
+
output_w = int(output_w * scale_ratio)
|
863 |
+
output_h = int(output_h * scale_ratio)
|
864 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
865 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
866 |
+
original_image = np.array(original_image)
|
867 |
+
if input_mask is not None:
|
868 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
869 |
+
input_mask = np.array(input_mask)
|
870 |
+
if original_mask is not None:
|
871 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
872 |
+
original_mask = np.array(original_mask)
|
873 |
+
|
874 |
+
|
875 |
+
if original_mask.ndim == 2:
|
876 |
+
original_mask = original_mask[:,:,None]
|
877 |
+
|
878 |
+
mask_image = Image.fromarray(original_mask.squeeze().astype(np.uint8)).convert("RGB")
|
879 |
+
|
880 |
+
masked_image = original_image * (1 - (original_mask>0))
|
881 |
+
masked_image = masked_image.astype(np.uint8)
|
882 |
+
masked_image = Image.fromarray(masked_image)
|
883 |
+
|
884 |
+
return [masked_image], [mask_image], original_mask.astype(np.uint8), category
|
885 |
+
|
886 |
+
|
887 |
+
def process_random_mask(input_image,
|
888 |
+
original_image,
|
889 |
+
original_mask,
|
890 |
+
resize_default,
|
891 |
+
aspect_ratio_name,
|
892 |
+
):
|
893 |
+
|
894 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
895 |
+
input_mask = np.asarray(alpha_mask)
|
896 |
+
|
897 |
+
output_w, output_h = aspect_ratios[aspect_ratio_name]
|
898 |
+
if output_w == "" or output_h == "":
|
899 |
+
output_h, output_w = original_image.shape[:2]
|
900 |
+
if resize_default:
|
901 |
+
short_side = min(output_w, output_h)
|
902 |
+
scale_ratio = 640 / short_side
|
903 |
+
output_w = int(output_w * scale_ratio)
|
904 |
+
output_h = int(output_h * scale_ratio)
|
905 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
906 |
+
original_image = np.array(original_image)
|
907 |
+
if input_mask is not None:
|
908 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
909 |
+
input_mask = np.array(input_mask)
|
910 |
+
if original_mask is not None:
|
911 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
912 |
+
original_mask = np.array(original_mask)
|
913 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
914 |
+
else:
|
915 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
916 |
+
pass
|
917 |
+
else:
|
918 |
+
if resize_default:
|
919 |
+
short_side = min(output_w, output_h)
|
920 |
+
scale_ratio = 640 / short_side
|
921 |
+
output_w = int(output_w * scale_ratio)
|
922 |
+
output_h = int(output_h * scale_ratio)
|
923 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
924 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
925 |
+
original_image = np.array(original_image)
|
926 |
+
if input_mask is not None:
|
927 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
928 |
+
input_mask = np.array(input_mask)
|
929 |
+
if original_mask is not None:
|
930 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
931 |
+
original_mask = np.array(original_mask)
|
932 |
+
|
933 |
+
|
934 |
+
if input_mask.max() == 0:
|
935 |
+
original_mask = original_mask
|
936 |
+
else:
|
937 |
+
original_mask = input_mask
|
938 |
+
|
939 |
+
if original_mask is None:
|
940 |
+
raise gr.Error('Please generate mask first')
|
941 |
+
|
942 |
+
if original_mask.ndim == 2:
|
943 |
+
original_mask = original_mask[:,:,None]
|
944 |
+
|
945 |
+
dilation_type = np.random.choice(['bounding_box', 'bounding_ellipse'])
|
946 |
+
random_mask = random_mask_func(original_mask, dilation_type).squeeze()
|
947 |
+
|
948 |
+
mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
|
949 |
+
|
950 |
+
masked_image = original_image * (1 - (random_mask[:,:,None]>0))
|
951 |
+
masked_image = masked_image.astype(original_image.dtype)
|
952 |
+
masked_image = Image.fromarray(masked_image)
|
953 |
+
|
954 |
+
|
955 |
+
return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
|
956 |
+
|
957 |
+
|
958 |
+
def process_dilation_mask(input_image,
                          original_image,
                          original_mask,
                          resize_default,
                          aspect_ratio_name,
                          dilation_size=20):

    alpha_mask = input_image["layers"][0].split()[3]
    input_mask = np.asarray(alpha_mask)

    output_w, output_h = aspect_ratios[aspect_ratio_name]
    if output_w == "" or output_h == "":
        output_h, output_w = original_image.shape[:2]
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
            original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
            original_image = np.array(original_image)
            if input_mask is not None:
                input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
                input_mask = np.array(input_mask)
            if original_mask is not None:
                original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
                original_mask = np.array(original_mask)
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        else:
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
            pass
    else:
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
        gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
        original_image = np.array(original_image)
        if input_mask is not None:
            input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
            input_mask = np.array(input_mask)
        if original_mask is not None:
            original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
            original_mask = np.array(original_mask)

    if input_mask.max() == 0:
        original_mask = original_mask
    else:
        original_mask = input_mask

    if original_mask is None:
        raise gr.Error('Please generate mask first')

    if original_mask.ndim == 2:
        original_mask = original_mask[:,:,None]

    dilation_type = np.random.choice(['square_dilation'])
    random_mask = random_mask_func(original_mask, dilation_type, dilation_size).squeeze()

    mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")

    masked_image = original_image * (1 - (random_mask[:,:,None]>0))
    masked_image = masked_image.astype(original_image.dtype)
    masked_image = Image.fromarray(masked_image)

    return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)

def process_erosion_mask(input_image,
                         original_image,
                         original_mask,
                         resize_default,
                         aspect_ratio_name,
                         dilation_size=20):
    alpha_mask = input_image["layers"][0].split()[3]
    input_mask = np.asarray(alpha_mask)
    output_w, output_h = aspect_ratios[aspect_ratio_name]
    if output_w == "" or output_h == "":
        output_h, output_w = original_image.shape[:2]
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
            original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
            original_image = np.array(original_image)
            if input_mask is not None:
                input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
                input_mask = np.array(input_mask)
            if original_mask is not None:
                original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
                original_mask = np.array(original_mask)
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        else:
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
            pass
    else:
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
        gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
        original_image = np.array(original_image)
        if input_mask is not None:
            input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
            input_mask = np.array(input_mask)
        if original_mask is not None:
            original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
            original_mask = np.array(original_mask)

    if input_mask.max() == 0:
        original_mask = original_mask
    else:
        original_mask = input_mask

    if original_mask is None:
        raise gr.Error('Please generate mask first')

    if original_mask.ndim == 2:
        original_mask = original_mask[:,:,None]

    dilation_type = np.random.choice(['square_erosion'])
    random_mask = random_mask_func(original_mask, dilation_type, dilation_size).squeeze()

    mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")

    masked_image = original_image * (1 - (random_mask[:,:,None]>0))
    masked_image = masked_image.astype(original_image.dtype)
    masked_image = Image.fromarray(masked_image)


    return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)

def move_mask_left(input_image,
                   original_image,
                   original_mask,
                   moving_pixels,
                   resize_default,
                   aspect_ratio_name):

    alpha_mask = input_image["layers"][0].split()[3]
    input_mask = np.asarray(alpha_mask)

    output_w, output_h = aspect_ratios[aspect_ratio_name]
    if output_w == "" or output_h == "":
        output_h, output_w = original_image.shape[:2]
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
            original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
            original_image = np.array(original_image)
            if input_mask is not None:
                input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
                input_mask = np.array(input_mask)
            if original_mask is not None:
                original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
                original_mask = np.array(original_mask)
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        else:
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
            pass
    else:
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
        gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
        original_image = np.array(original_image)
        if input_mask is not None:
            input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
            input_mask = np.array(input_mask)
        if original_mask is not None:
            original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
            original_mask = np.array(original_mask)

    if input_mask.max() == 0:
        original_mask = original_mask
    else:
        original_mask = input_mask

    if original_mask is None:
        raise gr.Error('Please generate mask first')

    if original_mask.ndim == 2:
        original_mask = original_mask[:,:,None]

    moved_mask = move_mask_func(original_mask, 'left', int(moving_pixels)).squeeze()
    mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")

    masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
    masked_image = masked_image.astype(original_image.dtype)
    masked_image = Image.fromarray(masked_image)

    if moved_mask.max() <= 1:
        moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
        original_mask = moved_mask
    return [masked_image], [mask_image], original_mask.astype(np.uint8)

def move_mask_right(input_image,
                    original_image,
                    original_mask,
                    moving_pixels,
                    resize_default,
                    aspect_ratio_name):
    alpha_mask = input_image["layers"][0].split()[3]
    input_mask = np.asarray(alpha_mask)

    output_w, output_h = aspect_ratios[aspect_ratio_name]
    if output_w == "" or output_h == "":
        output_h, output_w = original_image.shape[:2]
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
            original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
            original_image = np.array(original_image)
            if input_mask is not None:
                input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
                input_mask = np.array(input_mask)
            if original_mask is not None:
                original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
                original_mask = np.array(original_mask)
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        else:
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
            pass
    else:
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
        gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
        original_image = np.array(original_image)
        if input_mask is not None:
            input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
            input_mask = np.array(input_mask)
        if original_mask is not None:
            original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
            original_mask = np.array(original_mask)

    if input_mask.max() == 0:
        original_mask = original_mask
    else:
        original_mask = input_mask

    if original_mask is None:
        raise gr.Error('Please generate mask first')

    if original_mask.ndim == 2:
        original_mask = original_mask[:,:,None]

    moved_mask = move_mask_func(original_mask, 'right', int(moving_pixels)).squeeze()

    mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")

    masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
    masked_image = masked_image.astype(original_image.dtype)
    masked_image = Image.fromarray(masked_image)


    if moved_mask.max() <= 1:
        moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
        original_mask = moved_mask

    return [masked_image], [mask_image], original_mask.astype(np.uint8)

def move_mask_up(input_image,
                 original_image,
                 original_mask,
                 moving_pixels,
                 resize_default,
                 aspect_ratio_name):
    alpha_mask = input_image["layers"][0].split()[3]
    input_mask = np.asarray(alpha_mask)

    output_w, output_h = aspect_ratios[aspect_ratio_name]
    if output_w == "" or output_h == "":
        output_h, output_w = original_image.shape[:2]
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
            original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
            original_image = np.array(original_image)
            if input_mask is not None:
                input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
                input_mask = np.array(input_mask)
            if original_mask is not None:
                original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
                original_mask = np.array(original_mask)
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        else:
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
            pass
    else:
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
        gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
        original_image = np.array(original_image)
        if input_mask is not None:
            input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
            input_mask = np.array(input_mask)
        if original_mask is not None:
            original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
            original_mask = np.array(original_mask)

    if input_mask.max() == 0:
        original_mask = original_mask
    else:
        original_mask = input_mask

    if original_mask is None:
        raise gr.Error('Please generate mask first')

    if original_mask.ndim == 2:
        original_mask = original_mask[:,:,None]

    moved_mask = move_mask_func(original_mask, 'up', int(moving_pixels)).squeeze()
    mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")

    masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
    masked_image = masked_image.astype(original_image.dtype)
    masked_image = Image.fromarray(masked_image)

    if moved_mask.max() <= 1:
        moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
        original_mask = moved_mask

    return [masked_image], [mask_image], original_mask.astype(np.uint8)

def move_mask_down(input_image,
                   original_image,
                   original_mask,
                   moving_pixels,
                   resize_default,
                   aspect_ratio_name):
    alpha_mask = input_image["layers"][0].split()[3]
    input_mask = np.asarray(alpha_mask)
    output_w, output_h = aspect_ratios[aspect_ratio_name]
    if output_w == "" or output_h == "":
        output_h, output_w = original_image.shape[:2]
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
            original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
            original_image = np.array(original_image)
            if input_mask is not None:
                input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
                input_mask = np.array(input_mask)
            if original_mask is not None:
                original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
                original_mask = np.array(original_mask)
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        else:
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
            pass
    else:
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
        gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
        original_image = np.array(original_image)
        if input_mask is not None:
            input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
            input_mask = np.array(input_mask)
        if original_mask is not None:
            original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
            original_mask = np.array(original_mask)

    if input_mask.max() == 0:
        original_mask = original_mask
    else:
        original_mask = input_mask

    if original_mask is None:
        raise gr.Error('Please generate mask first')

    if original_mask.ndim == 2:
        original_mask = original_mask[:,:,None]

    moved_mask = move_mask_func(original_mask, 'down', int(moving_pixels)).squeeze()
    mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")

    masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
    masked_image = masked_image.astype(original_image.dtype)
    masked_image = Image.fromarray(masked_image)

    if moved_mask.max() <= 1:
        moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
        original_mask = moved_mask

    return [masked_image], [mask_image], original_mask.astype(np.uint8)

def invert_mask(input_image,
                original_image,
                original_mask,
                ):
    alpha_mask = input_image["layers"][0].split()[3]
    input_mask = np.asarray(alpha_mask)
    if input_mask.max() == 0:
        original_mask = 1 - (original_mask>0).astype(np.uint8)
    else:
        original_mask = 1 - (input_mask>0).astype(np.uint8)

    if original_mask is None:
        raise gr.Error('Please generate mask first')

    original_mask = original_mask.squeeze()
    mask_image = Image.fromarray(original_mask*255).convert("RGB")

    if original_mask.ndim == 2:
        original_mask = original_mask[:,:,None]

    if original_mask.max() <= 1:
        original_mask = (original_mask * 255).astype(np.uint8)

    masked_image = original_image * (1 - (original_mask>0))
    masked_image = masked_image.astype(original_image.dtype)
    masked_image = Image.fromarray(masked_image)

    return [masked_image], [mask_image], original_mask, True

def init_img(base,
             init_type,
             prompt,
             aspect_ratio,
             example_change_times
             ):
    image_pil = base["background"].convert("RGB")
    original_image = np.array(image_pil)
    if max(original_image.shape[0], original_image.shape[1]) * 1.0 / min(original_image.shape[0], original_image.shape[1])>2.0:
        raise gr.Error('image aspect ratio cannot be larger than 2.0')
    if init_type in MASK_IMAGE_PATH.keys() and example_change_times < 2:
        mask_gallery = [Image.open(MASK_IMAGE_PATH[init_type]).convert("L")]
        masked_gallery = [Image.open(MASKED_IMAGE_PATH[init_type]).convert("RGB")]
        result_gallery = [Image.open(OUTPUT_IMAGE_PATH[init_type]).convert("RGB")]
        width, height = image_pil.size
        image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
        height_new, width_new = image_processor.get_default_height_width(image_pil, height, width)
        image_pil = image_pil.resize((width_new, height_new))
        mask_gallery[0] = mask_gallery[0].resize((width_new, height_new))
        masked_gallery[0] = masked_gallery[0].resize((width_new, height_new))
        result_gallery[0] = result_gallery[0].resize((width_new, height_new))
        original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1
        return base, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, "", "", "Custom resolution", False, False, example_change_times
    else:
        if aspect_ratio not in ASPECT_RATIO_LABELS:
            aspect_ratio = "Custom resolution"
        return base, original_image, None, "", None, None, None, "", "", aspect_ratio, True, False, 0

def reset_func(input_image,
               original_image,
               original_mask,
               prompt,
               target_prompt,
               ):
    input_image = None
    original_image = None
    original_mask = None
    prompt = ''
    mask_gallery = []
    masked_gallery = []
    result_gallery = []
    target_prompt = ''
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, True, False

def update_example(example_type,
                   prompt,
                   example_change_times):
    input_image = INPUT_IMAGE_PATH[example_type]
    image_pil = Image.open(input_image).convert("RGB")
    mask_gallery = [Image.open(MASK_IMAGE_PATH[example_type]).convert("L")]
    masked_gallery = [Image.open(MASKED_IMAGE_PATH[example_type]).convert("RGB")]
    result_gallery = [Image.open(OUTPUT_IMAGE_PATH[example_type]).convert("RGB")]
    width, height = image_pil.size
    image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
    height_new, width_new = image_processor.get_default_height_width(image_pil, height, width)
    image_pil = image_pil.resize((width_new, height_new))
    mask_gallery[0] = mask_gallery[0].resize((width_new, height_new))
    masked_gallery[0] = masked_gallery[0].resize((width_new, height_new))
    result_gallery[0] = result_gallery[0].resize((width_new, height_new))

    original_image = np.array(image_pil)
    original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1
    aspect_ratio = "Custom resolution"
    example_change_times += 1
    return input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, "", False, example_change_times

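## Gradio UI (descriptive comment added for readability): the layout below builds the editor,
## prompt, and mask-control widgets, keeps the original image/mask in gr.State between callbacks,
## and then wires every button to the mask helpers and the BrushEdit process function defined above.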
block = gr.Blocks(
        theme=gr.themes.Soft(
            radius_size=gr.themes.sizes.radius_none,
            text_size=gr.themes.sizes.text_md
        )
    )
with block as demo:
    with gr.Row():
        with gr.Column():
            gr.HTML(head)

    gr.Markdown(descriptions)

    with gr.Accordion(label="🧭 Instructions:", open=True, elem_id="accordion"):
        with gr.Row(equal_height=True):
            gr.Markdown(instructions)

    original_image = gr.State(value=None)
    original_mask = gr.State(value=None)
    category = gr.State(value=None)
    status = gr.State(value=None)
    invert_mask_state = gr.State(value=False)
    example_change_times = gr.State(value=0)


    with gr.Row():
        with gr.Column():
            with gr.Row():
                input_image = gr.ImageEditor(
                    label="Input Image",
                    type="pil",
                    brush=gr.Brush(colors=["#FFFFFF"], default_size = 30, color_mode="fixed"),
                    layers = False,
                    interactive=True,
                    height=1024,
                    sources=["upload"],
                    placeholder="Please click here or the icon below to upload the image.",
                )

            prompt = gr.Textbox(label="⌨️ Instruction", placeholder="Please input your instruction.", value="",lines=1)
            run_button = gr.Button("💫 Run")

            vlm_model_dropdown = gr.Dropdown(label="VLM model", choices=VLM_MODEL_NAMES, value=DEFAULT_VLM_MODEL_NAME, interactive=True)
            with gr.Group():
                with gr.Row():
                    GPT4o_KEY = gr.Textbox(label="GPT4o API Key", placeholder="Please input your GPT4o API Key when use GPT4o VLM (highly recommended).", value="", lines=1)

                    GPT4o_KEY_submit = gr.Button("Submit and Verify")


            aspect_ratio = gr.Dropdown(label="Output aspect ratio", choices=ASPECT_RATIO_LABELS, value=DEFAULT_ASPECT_RATIO)
            resize_default = gr.Checkbox(label="Short edge resize to 640px", value=True)

            with gr.Row():
                mask_button = gr.Button("Generate Mask")
                random_mask_button = gr.Button("Square/Circle Mask ")


            with gr.Row():
                generate_target_prompt_button = gr.Button("Generate Target Prompt")

            target_prompt = gr.Text(
                label="Input Target Prompt",
                max_lines=5,
                placeholder="VLM-generated target prompt, you can first generate if and then modify it (optional)",
                value='',
                lines=2
            )

            with gr.Accordion("Advanced Options", open=False, elem_id="accordion1"):
                base_model_dropdown = gr.Dropdown(label="Base model", choices=BASE_MODELS, value=DEFAULT_BASE_MODEL, interactive=True)
                negative_prompt = gr.Text(
                    label="Negative Prompt",
                    max_lines=5,
                    placeholder="Please input your negative prompt",
                    value='ugly, low quality',lines=1
                )

                control_strength = gr.Slider(
                    label="Control Strength: ", show_label=True, minimum=0, maximum=1.1, value=1, step=0.01
                )
                with gr.Group():
                    seed = gr.Slider(
                        label="Seed: ", minimum=0, maximum=2147483647, step=1, value=648464818
                    )
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=False)

                blending = gr.Checkbox(label="Blending mode", value=True)


                num_samples = gr.Slider(
                    label="Num samples", minimum=0, maximum=4, step=1, value=4
                )

                with gr.Group():
                    with gr.Row():
                        guidance_scale = gr.Slider(
                            label="Guidance scale",
                            minimum=1,
                            maximum=12,
                            step=0.1,
                            value=7.5,
                        )
                        num_inference_steps = gr.Slider(
                            label="Number of inference steps",
                            minimum=1,
                            maximum=50,
                            step=1,
                            value=50,
                        )


        with gr.Column():
            with gr.Row():
                with gr.Tab(elem_classes="feedback", label="Masked Image"):
                    masked_gallery = gr.Gallery(label='Masked Image', show_label=True, elem_id="gallery", preview=True, height=360)
                with gr.Tab(elem_classes="feedback", label="Mask"):
                    mask_gallery = gr.Gallery(label='Mask', show_label=True, elem_id="gallery", preview=True, height=360)

            invert_mask_button = gr.Button("Invert Mask")
            dilation_size = gr.Slider(
                label="Dilation size: ", minimum=0, maximum=50, step=1, value=20
            )
            with gr.Row():
                dilation_mask_button = gr.Button("Dilation Generated Mask")
                erosion_mask_button = gr.Button("Erosion Generated Mask")

            moving_pixels = gr.Slider(
                label="Moving pixels:", show_label=True, minimum=0, maximum=50, value=4, step=1
            )
            with gr.Row():
                move_left_button = gr.Button("Move Left")
                move_right_button = gr.Button("Move Right")
            with gr.Row():
                move_up_button = gr.Button("Move Up")
                move_down_button = gr.Button("Move Down")

            with gr.Tab(elem_classes="feedback", label="Output"):
                result_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", preview=True, height=400)

            # target_prompt_output = gr.Text(label="Output Target Prompt", value="", lines=1, interactive=False)

            reset_button = gr.Button("Reset")

            init_type = gr.Textbox(label="Init Name", value="", visible=False)
            example_type = gr.Textbox(label="Example Name", value="", visible=False)



    with gr.Row():
        example = gr.Examples(
            label="Quick Example",
            examples=EXAMPLES,
            inputs=[input_image, prompt, seed, init_type, example_type, blending, resize_default, vlm_model_dropdown],
            examples_per_page=10,
            cache_examples=False,
        )


    with gr.Accordion(label="🎬 Feature Details:", open=True, elem_id="accordion"):
        with gr.Row(equal_height=True):
            gr.Markdown(tips)

    with gr.Row():
        gr.Markdown(citation)

    ## gr.examples can not be used to update the gr.Gallery, so we need to use the following two functions to update the gr.Gallery.
    ## And we need to solve the conflict between the upload and change example functions.
    input_image.upload(
        init_img,
        [input_image, init_type, prompt, aspect_ratio, example_change_times],
        [input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, init_type, aspect_ratio, resize_default, invert_mask_state, example_change_times]
    )
    example_type.change(fn=update_example, inputs=[example_type, prompt, example_change_times], outputs=[input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, target_prompt, invert_mask_state, example_change_times])

    ## vlm and base model dropdown
    vlm_model_dropdown.change(fn=update_vlm_model, inputs=[vlm_model_dropdown], outputs=[status])
    base_model_dropdown.change(fn=update_base_model, inputs=[base_model_dropdown], outputs=[status])


    GPT4o_KEY_submit.click(fn=submit_GPT4o_KEY, inputs=[GPT4o_KEY], outputs=[GPT4o_KEY, vlm_model_dropdown])
    invert_mask_button.click(fn=invert_mask, inputs=[input_image, original_image, original_mask], outputs=[masked_gallery, mask_gallery, original_mask, invert_mask_state])


    ips=[input_image,
         original_image,
         original_mask,
         prompt,
         negative_prompt,
         control_strength,
         seed,
         randomize_seed,
         guidance_scale,
         num_inference_steps,
         num_samples,
         blending,
         category,
         target_prompt,
         resize_default,
         aspect_ratio,
         invert_mask_state]

    ## run brushedit
    run_button.click(fn=process, inputs=ips, outputs=[result_gallery, mask_gallery, masked_gallery, prompt, target_prompt, invert_mask_state])

    ## mask func
    mask_button.click(fn=process_mask, inputs=[input_image, original_image, prompt, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask, category])
    random_mask_button.click(fn=process_random_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
    dilation_mask_button.click(fn=process_dilation_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[ masked_gallery, mask_gallery, original_mask])
    erosion_mask_button.click(fn=process_erosion_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[ masked_gallery, mask_gallery, original_mask])

    ## move mask func
    move_left_button.click(fn=move_mask_left, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
    move_right_button.click(fn=move_mask_right, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
    move_up_button.click(fn=move_mask_up, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
    move_down_button.click(fn=move_mask_down, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])

    ## prompt func
    generate_target_prompt_button.click(fn=generate_target_prompt, inputs=[input_image, original_image, prompt], outputs=[target_prompt])

    ## reset func
    reset_button.click(fn=reset_func, inputs=[input_image, original_image, original_mask, prompt, target_prompt], outputs=[input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, resize_default, invert_mask_state])

## if have a localhost access error, try to use the following code
demo.launch(server_name="0.0.0.0", server_port=12345, share=True)
# demo.launch()
brushedit_app_only_integrate.py
ADDED
@@ -0,0 +1,1725 @@
1 |
+
##!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import os, random, sys
|
4 |
+
import numpy as np
|
5 |
+
import requests
|
6 |
+
import torch
|
7 |
+
|
8 |
+
|
9 |
+
import gradio as gr
|
10 |
+
|
11 |
+
from PIL import Image
|
12 |
+
|
13 |
+
|
14 |
+
from huggingface_hub import hf_hub_download, snapshot_download
|
15 |
+
from scipy.ndimage import binary_dilation, binary_erosion
|
16 |
+
from transformers import (LlavaNextProcessor, LlavaNextForConditionalGeneration,
|
17 |
+
Qwen2VLForConditionalGeneration, Qwen2VLProcessor)
|
18 |
+
|
19 |
+
from segment_anything import SamPredictor, build_sam, SamAutomaticMaskGenerator
|
20 |
+
from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler
|
21 |
+
from diffusers.image_processor import VaeImageProcessor
|
22 |
+
|
23 |
+
|
24 |
+
from app.src.vlm_pipeline import (
|
25 |
+
vlm_response_editing_type,
|
26 |
+
vlm_response_object_wait_for_edit,
|
27 |
+
vlm_response_mask,
|
28 |
+
vlm_response_prompt_after_apply_instruction
|
29 |
+
)
|
30 |
+
from app.src.brushedit_all_in_one_pipeline import BrushEdit_Pipeline
|
31 |
+
from app.utils.utils import load_grounding_dino_model
|
32 |
+
|
33 |
+
from app.src.vlm_template import vlms_template
|
34 |
+
from app.src.base_model_template import base_models_template
|
35 |
+
from app.src.aspect_ratio_template import aspect_ratios
|
36 |
+
|
37 |
+
from openai import OpenAI
|
38 |
+
# base_openai_url = "https://api.deepseek.com/"
|
39 |
+
|
40 |
+
|
41 |
+
from transformers import BlipProcessor, BlipForConditionalGeneration
|
42 |
+
|
43 |
+
from app.deepseek.instructions import create_apply_editing_messages_deepseek
|
44 |
+
|
45 |
+
|
46 |
+
#### Description ####
|
47 |
+
logo = r"""
|
48 |
+
<center><img src='./assets/logo_brushedit.png' alt='BrushEdit logo' style="width:80px; margin-bottom:10px"></center>
|
49 |
+
"""
|
50 |
+
head = r"""
|
51 |
+
<div style="text-align: center;">
|
52 |
+
<h1> BrushEdit: All-In-One Image Inpainting and Editing</h1>
|
53 |
+
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
|
54 |
+
<a href='https://liyaowei-stu.github.io/project/BrushEdit/'><img src='https://img.shields.io/badge/Project_Page-BrushEdit-green' alt='Project Page'></a>
|
55 |
+
<a href='https://arxiv.org/abs/2412.10316'><img src='https://img.shields.io/badge/Paper-Arxiv-blue'></a>
|
56 |
+
<a href='https://github.com/TencentARC/BrushEdit'><img src='https://img.shields.io/badge/Code-Github-orange'></a>
|
57 |
+
|
58 |
+
</div>
|
59 |
+
</br>
|
60 |
+
</div>
|
61 |
+
"""
|
62 |
+
descriptions = r"""
|
63 |
+
Demo for CIR"""
|
64 |
+
|
65 |
+
instructions = r"""
|
66 |
+
Demo for CIR"""
|
67 |
+
|
68 |
+
tips = r"""
|
69 |
+
Demo for CIR
|
70 |
+
|
71 |
+
"""
|
72 |
+
|
73 |
+
|
74 |
+
|
75 |
+
citation = r"""
|
76 |
+
Demo for CIR"""
|
77 |
+
|
78 |
+
# - - - - - examples - - - - - #
|
79 |
+
EXAMPLES = [
|
80 |
+
|
81 |
+
[
|
82 |
+
Image.open("./assets/frog/frog.jpeg").convert("RGBA"),
|
83 |
+
"add a magic hat on frog head.",
|
84 |
+
642087011,
|
85 |
+
"frog",
|
86 |
+
"frog",
|
87 |
+
True,
|
88 |
+
False,
|
89 |
+
"GPT4-o (Highly Recommended)"
|
90 |
+
],
|
91 |
+
[
|
92 |
+
Image.open("./assets/chinese_girl/chinese_girl.png").convert("RGBA"),
|
93 |
+
"replace the background to ancient China.",
|
94 |
+
648464818,
|
95 |
+
"chinese_girl",
|
96 |
+
"chinese_girl",
|
97 |
+
True,
|
98 |
+
False,
|
99 |
+
"GPT4-o (Highly Recommended)"
|
100 |
+
],
|
101 |
+
[
|
102 |
+
Image.open("./assets/angel_christmas/angel_christmas.png").convert("RGBA"),
|
103 |
+
"remove the deer.",
|
104 |
+
648464818,
|
105 |
+
"angel_christmas",
|
106 |
+
"angel_christmas",
|
107 |
+
False,
|
108 |
+
False,
|
109 |
+
"GPT4-o (Highly Recommended)"
|
110 |
+
],
|
111 |
+
[
|
112 |
+
Image.open("./assets/sunflower_girl/sunflower_girl.png").convert("RGBA"),
|
113 |
+
"add a wreath on head.",
|
114 |
+
648464818,
|
115 |
+
"sunflower_girl",
|
116 |
+
"sunflower_girl",
|
117 |
+
True,
|
118 |
+
False,
|
119 |
+
"GPT4-o (Highly Recommended)"
|
120 |
+
],
|
121 |
+
[
|
122 |
+
Image.open("./assets/girl_on_sun/girl_on_sun.png").convert("RGBA"),
|
123 |
+
"add a butterfly fairy.",
|
124 |
+
648464818,
|
125 |
+
"girl_on_sun",
|
126 |
+
"girl_on_sun",
|
127 |
+
True,
|
128 |
+
False,
|
129 |
+
"GPT4-o (Highly Recommended)"
|
130 |
+
],
|
131 |
+
[
|
132 |
+
Image.open("./assets/spider_man_rm/spider_man.png").convert("RGBA"),
|
133 |
+
"remove the christmas hat.",
|
134 |
+
642087011,
|
135 |
+
"spider_man_rm",
|
136 |
+
"spider_man_rm",
|
137 |
+
False,
|
138 |
+
False,
|
139 |
+
"GPT4-o (Highly Recommended)"
|
140 |
+
],
|
141 |
+
[
|
142 |
+
Image.open("./assets/anime_flower/anime_flower.png").convert("RGBA"),
|
143 |
+
"remove the flower.",
|
144 |
+
642087011,
|
145 |
+
"anime_flower",
|
146 |
+
"anime_flower",
|
147 |
+
False,
|
148 |
+
False,
|
149 |
+
"GPT4-o (Highly Recommended)"
|
150 |
+
],
|
151 |
+
[
|
152 |
+
Image.open("./assets/chenduling/chengduling.jpg").convert("RGBA"),
|
153 |
+
"replace the clothes to a delicated floral skirt.",
|
154 |
+
648464818,
|
155 |
+
"chenduling",
|
156 |
+
"chenduling",
|
157 |
+
True,
|
158 |
+
False,
|
159 |
+
"GPT4-o (Highly Recommended)"
|
160 |
+
],
|
161 |
+
[
|
162 |
+
Image.open("./assets/hedgehog_rp_bg/hedgehog.png").convert("RGBA"),
|
163 |
+
"make the hedgehog in Italy.",
|
164 |
+
648464818,
|
165 |
+
"hedgehog_rp_bg",
|
166 |
+
"hedgehog_rp_bg",
|
167 |
+
True,
|
168 |
+
False,
|
169 |
+
"GPT4-o (Highly Recommended)"
|
170 |
+
],
|
171 |
+
|
172 |
+
]
|
173 |
+
|
174 |
+
INPUT_IMAGE_PATH = {
|
175 |
+
"frog": "./assets/frog/frog.jpeg",
|
176 |
+
"chinese_girl": "./assets/chinese_girl/chinese_girl.png",
|
177 |
+
"angel_christmas": "./assets/angel_christmas/angel_christmas.png",
|
178 |
+
"sunflower_girl": "./assets/sunflower_girl/sunflower_girl.png",
|
179 |
+
"girl_on_sun": "./assets/girl_on_sun/girl_on_sun.png",
|
180 |
+
"spider_man_rm": "./assets/spider_man_rm/spider_man.png",
|
181 |
+
"anime_flower": "./assets/anime_flower/anime_flower.png",
|
182 |
+
"chenduling": "./assets/chenduling/chengduling.jpg",
|
183 |
+
"hedgehog_rp_bg": "./assets/hedgehog_rp_bg/hedgehog.png",
|
184 |
+
}
|
185 |
+
MASK_IMAGE_PATH = {
|
186 |
+
"frog": "./assets/frog/mask_f7b350de-6f2c-49e3-b535-995c486d78e7.png",
|
187 |
+
"chinese_girl": "./assets/chinese_girl/mask_54759648-0989-48e0-bc82-f20e28b5ec29.png",
|
188 |
+
"angel_christmas": "./assets/angel_christmas/mask_f15d9b45-c978-4e3d-9f5f-251e308560c3.png",
|
189 |
+
"sunflower_girl": "./assets/sunflower_girl/mask_99cc50b4-7dc4-4de5-8748-ec10772f0317.png",
|
190 |
+
"girl_on_sun": "./assets/girl_on_sun/mask_264eac8b-8b65-479c-9755-020a60880c37.png",
|
191 |
+
"spider_man_rm": "./assets/spider_man_rm/mask_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png",
|
192 |
+
"anime_flower": "./assets/anime_flower/mask_37553172-9b38-4727-bf2e-37d7e2b93461.png",
|
193 |
+
"chenduling": "./assets/chenduling/mask_68e3ff6f-da07-4b37-91df-13d6eed7b997.png",
|
194 |
+
"hedgehog_rp_bg": "./assets/hedgehog_rp_bg/mask_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png",
|
195 |
+
}
|
196 |
+
MASKED_IMAGE_PATH = {
|
197 |
+
"frog": "./assets/frog/masked_image_f7b350de-6f2c-49e3-b535-995c486d78e7.png",
|
198 |
+
"chinese_girl": "./assets/chinese_girl/masked_image_54759648-0989-48e0-bc82-f20e28b5ec29.png",
|
199 |
+
"angel_christmas": "./assets/angel_christmas/masked_image_f15d9b45-c978-4e3d-9f5f-251e308560c3.png",
|
200 |
+
"sunflower_girl": "./assets/sunflower_girl/masked_image_99cc50b4-7dc4-4de5-8748-ec10772f0317.png",
|
201 |
+
"girl_on_sun": "./assets/girl_on_sun/masked_image_264eac8b-8b65-479c-9755-020a60880c37.png",
|
202 |
+
"spider_man_rm": "./assets/spider_man_rm/masked_image_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png",
|
203 |
+
"anime_flower": "./assets/anime_flower/masked_image_37553172-9b38-4727-bf2e-37d7e2b93461.png",
|
204 |
+
"chenduling": "./assets/chenduling/masked_image_68e3ff6f-da07-4b37-91df-13d6eed7b997.png",
|
205 |
+
"hedgehog_rp_bg": "./assets/hedgehog_rp_bg/masked_image_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png",
|
206 |
+
}
|
207 |
+
OUTPUT_IMAGE_PATH = {
|
208 |
+
"frog": "./assets/frog/image_edit_f7b350de-6f2c-49e3-b535-995c486d78e7_1.png",
|
209 |
+
"chinese_girl": "./assets/chinese_girl/image_edit_54759648-0989-48e0-bc82-f20e28b5ec29_1.png",
|
210 |
+
"angel_christmas": "./assets/angel_christmas/image_edit_f15d9b45-c978-4e3d-9f5f-251e308560c3_0.png",
|
211 |
+
"sunflower_girl": "./assets/sunflower_girl/image_edit_99cc50b4-7dc4-4de5-8748-ec10772f0317_3.png",
|
212 |
+
"girl_on_sun": "./assets/girl_on_sun/image_edit_264eac8b-8b65-479c-9755-020a60880c37_0.png",
|
213 |
+
"spider_man_rm": "./assets/spider_man_rm/image_edit_a5d410e6-8e8d-432f-8144-defbc3e1eae9_0.png",
|
214 |
+
"anime_flower": "./assets/anime_flower/image_edit_37553172-9b38-4727-bf2e-37d7e2b93461_2.png",
|
215 |
+
"chenduling": "./assets/chenduling/image_edit_68e3ff6f-da07-4b37-91df-13d6eed7b997_0.png",
|
216 |
+
"hedgehog_rp_bg": "./assets/hedgehog_rp_bg/image_edit_db7f8bf8-8349-46d3-b14e-43d67fbe25d3_3.png",
|
217 |
+
}
|
218 |
+
|
219 |
+
# os.environ['GRADIO_TEMP_DIR'] = 'gradio_temp_dir'
|
220 |
+
# os.makedirs('gradio_temp_dir', exist_ok=True)
|
221 |
+
|
222 |
+
VLM_MODEL_NAMES = list(vlms_template.keys())
|
223 |
+
DEFAULT_VLM_MODEL_NAME = "Qwen2-VL-7B-Instruct (Default)"
|
224 |
+
|
225 |
+
|
226 |
+
BASE_MODELS = list(base_models_template.keys())
|
227 |
+
DEFAULT_BASE_MODEL = "realisticVision (Default)"
|
228 |
+
|
229 |
+
ASPECT_RATIO_LABELS = list(aspect_ratios)
|
230 |
+
DEFAULT_ASPECT_RATIO = ASPECT_RATIO_LABELS[0]
|
231 |
+
|
232 |
+
|
233 |
+
## init device
|
234 |
+
try:
|
235 |
+
if torch.cuda.is_available():
|
236 |
+
device = "cuda"
|
237 |
+
elif sys.platform == "darwin" and torch.backends.mps.is_available():
|
238 |
+
device = "mps"
|
239 |
+
else:
|
240 |
+
device = "cpu"
|
241 |
+
except:
|
242 |
+
device = "cpu"
|
243 |
+
|
244 |
+
# ## init torch dtype
|
245 |
+
# if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
|
246 |
+
# torch_dtype = torch.bfloat16
|
247 |
+
# else:
|
248 |
+
# torch_dtype = torch.float16
|
249 |
+
|
250 |
+
# if device == "mps":
|
251 |
+
# torch_dtype = torch.float16
|
252 |
+
|
253 |
+
torch_dtype = torch.float16
|
254 |
+
|
255 |
+
|
256 |
+
|
257 |
+
# download hf models
|
258 |
+
BrushEdit_path = "models/"
|
259 |
+
if not os.path.exists(BrushEdit_path):
|
260 |
+
BrushEdit_path = snapshot_download(
|
261 |
+
repo_id="TencentARC/BrushEdit",
|
262 |
+
local_dir=BrushEdit_path,
|
263 |
+
token=os.getenv("HF_TOKEN"),
|
264 |
+
)
|
265 |
+
|
266 |
+
## init default VLM
|
267 |
+
vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[DEFAULT_VLM_MODEL_NAME]
|
268 |
+
if vlm_processor != "" and vlm_model != "":
|
269 |
+
vlm_model.to(device)
|
270 |
+
else:
|
271 |
+
raise gr.Error("Please Download default VLM model "+ DEFAULT_VLM_MODEL_NAME +" first.")
|
272 |
+
|
273 |
+
def initialize_llm_model():
|
274 |
+
global llm_model
|
275 |
+
llm_model = OpenAI(api_key="sk-d145b963a92649a88843caeb741e8bbc", base_url="https://api.deepseek.com")
|
276 |
+
|
277 |
+
## init base model
|
278 |
+
base_model_path = os.path.join(BrushEdit_path, "base_model/realisticVisionV60B1_v51VAE")
|
279 |
+
brushnet_path = os.path.join(BrushEdit_path, "brushnetX")
|
280 |
+
sam_path = os.path.join(BrushEdit_path, "sam/sam_vit_h_4b8939.pth")
|
281 |
+
groundingdino_path = os.path.join(BrushEdit_path, "grounding_dino/groundingdino_swint_ogc.pth")
|
282 |
+
|
283 |
+
|
284 |
+
# input brushnetX ckpt path
|
285 |
+
brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch_dtype)
|
286 |
+
pipe = StableDiffusionBrushNetPipeline.from_pretrained(
|
287 |
+
base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
|
288 |
+
)
|
289 |
+
# speed up diffusion process with faster scheduler and memory optimization
|
290 |
+
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
|
291 |
+
# remove following line if xformers is not installed or when using Torch 2.0.
|
292 |
+
# pipe.enable_xformers_memory_efficient_attention()
|
293 |
+
pipe.enable_model_cpu_offload()
|
294 |
+
|
295 |
+
|
296 |
+
## init SAM
|
297 |
+
sam = build_sam(checkpoint=sam_path)
|
298 |
+
sam.to(device=device)
|
299 |
+
sam_predictor = SamPredictor(sam)
|
300 |
+
sam_automask_generator = SamAutomaticMaskGenerator(sam)
|
301 |
+
|
302 |
+
## init groundingdino_model
|
303 |
+
config_file = 'app/utils/GroundingDINO_SwinT_OGC.py'
|
304 |
+
groundingdino_model = load_grounding_dino_model(config_file, groundingdino_path, device=device)
|
305 |
+
|
306 |
+
## Ordinary function
|
307 |
+
def crop_and_resize(image: Image.Image,
|
308 |
+
target_width: int,
|
309 |
+
target_height: int) -> Image.Image:
|
310 |
+
"""
|
311 |
+
Crops and resizes an image while preserving the aspect ratio.
|
312 |
+
|
313 |
+
Args:
|
314 |
+
image (Image.Image): Input PIL image to be cropped and resized.
|
315 |
+
target_width (int): Target width of the output image.
|
316 |
+
target_height (int): Target height of the output image.
|
317 |
+
|
318 |
+
Returns:
|
319 |
+
Image.Image: Cropped and resized image.
|
320 |
+
"""
|
321 |
+
# Original dimensions
|
322 |
+
original_width, original_height = image.size
|
323 |
+
original_aspect = original_width / original_height
|
324 |
+
target_aspect = target_width / target_height
|
325 |
+
|
326 |
+
# Calculate crop box to maintain aspect ratio
|
327 |
+
if original_aspect > target_aspect:
|
328 |
+
# Crop horizontally
|
329 |
+
new_width = int(original_height * target_aspect)
|
330 |
+
new_height = original_height
|
331 |
+
left = (original_width - new_width) / 2
|
332 |
+
top = 0
|
333 |
+
right = left + new_width
|
334 |
+
bottom = original_height
|
335 |
+
else:
|
336 |
+
# Crop vertically
|
337 |
+
new_width = original_width
|
338 |
+
new_height = int(original_width / target_aspect)
|
339 |
+
left = 0
|
340 |
+
top = (original_height - new_height) / 2
|
341 |
+
right = original_width
|
342 |
+
bottom = top + new_height
|
343 |
+
|
344 |
+
# Crop and resize
|
345 |
+
cropped_image = image.crop((left, top, right, bottom))
|
346 |
+
resized_image = cropped_image.resize((target_width, target_height), Image.NEAREST)
|
347 |
+
return resized_image
|
348 |
+
|
349 |
+
|
350 |
+
## Ordinary function
|
351 |
+
def resize(image: Image.Image,
|
352 |
+
target_width: int,
|
353 |
+
target_height: int) -> Image.Image:
|
354 |
+
"""
|
355 |
+
Crops and resizes an image while preserving the aspect ratio.
|
356 |
+
|
357 |
+
Args:
|
358 |
+
image (Image.Image): Input PIL image to be cropped and resized.
|
359 |
+
target_width (int): Target width of the output image.
|
360 |
+
target_height (int): Target height of the output image.
|
361 |
+
|
362 |
+
Returns:
|
363 |
+
Image.Image: Cropped and resized image.
|
364 |
+
"""
|
365 |
+
# Original dimensions
|
366 |
+
resized_image = image.resize((target_width, target_height), Image.NEAREST)
|
367 |
+
return resized_image
|
368 |
+
|
369 |
+
|
370 |
+
def move_mask_func(mask, direction, units):
|
371 |
+
binary_mask = mask.squeeze()>0
|
372 |
+
rows, cols = binary_mask.shape
|
373 |
+
moved_mask = np.zeros_like(binary_mask, dtype=bool)
|
374 |
+
|
375 |
+
if direction == 'down':
|
376 |
+
# move down
|
377 |
+
moved_mask[max(0, units):, :] = binary_mask[:rows - units, :]
|
378 |
+
|
379 |
+
elif direction == 'up':
|
380 |
+
# move up
|
381 |
+
moved_mask[:rows - units, :] = binary_mask[units:, :]
|
382 |
+
|
383 |
+
elif direction == 'right':
|
384 |
+
# move left
|
385 |
+
moved_mask[:, max(0, units):] = binary_mask[:, :cols - units]
|
386 |
+
|
387 |
+
elif direction == 'left':
|
388 |
+
# move right
|
389 |
+
moved_mask[:, :cols - units] = binary_mask[:, units:]
|
390 |
+
|
391 |
+
return moved_mask
|
392 |
+
|
393 |
+
|
394 |
+
def random_mask_func(mask, dilation_type='square', dilation_size=20):
|
395 |
+
# Randomly select the size of dilation
|
396 |
+
binary_mask = mask.squeeze()>0
|
397 |
+
|
398 |
+
if dilation_type == 'square_dilation':
|
399 |
+
structure = np.ones((dilation_size, dilation_size), dtype=bool)
|
400 |
+
dilated_mask = binary_dilation(binary_mask, structure=structure)
|
401 |
+
elif dilation_type == 'square_erosion':
|
402 |
+
structure = np.ones((dilation_size, dilation_size), dtype=bool)
|
403 |
+
dilated_mask = binary_erosion(binary_mask, structure=structure)
|
404 |
+
elif dilation_type == 'bounding_box':
|
405 |
+
# find the bounding box of the mask
|
406 |
+
rows, cols = np.where(binary_mask)
|
407 |
+
if len(rows) == 0 or len(cols) == 0:
|
408 |
+
return mask # return original mask if no valid points
|
409 |
+
|
410 |
+
min_row = np.min(rows)
|
411 |
+
max_row = np.max(rows)
|
412 |
+
min_col = np.min(cols)
|
413 |
+
max_col = np.max(cols)
|
414 |
+
|
415 |
+
# create a bounding box
|
416 |
+
dilated_mask = np.zeros_like(binary_mask, dtype=bool)
|
417 |
+
dilated_mask[min_row:max_row + 1, min_col:max_col + 1] = True
|
418 |
+
|
419 |
+
elif dilation_type == 'bounding_ellipse':
|
420 |
+
# find the bounding box of the mask
|
421 |
+
rows, cols = np.where(binary_mask)
|
422 |
+
if len(rows) == 0 or len(cols) == 0:
|
423 |
+
return mask # return original mask if no valid points
|
424 |
+
|
425 |
+
min_row = np.min(rows)
|
426 |
+
max_row = np.max(rows)
|
427 |
+
min_col = np.min(cols)
|
428 |
+
max_col = np.max(cols)
|
429 |
+
|
430 |
+
# calculate the center and axis length of the ellipse
|
431 |
+
center = ((min_col + max_col) // 2, (min_row + max_row) // 2)
|
432 |
+
a = (max_col - min_col) // 2  # semi-axis along the x (column) direction
|
433 |
+
b = (max_row - min_row) // 2  # semi-axis along the y (row) direction
|
434 |
+
|
435 |
+
# create a bounding ellipse
|
436 |
+
y, x = np.ogrid[:mask.shape[0], :mask.shape[1]]
|
437 |
+
ellipse_mask = ((x - center[0])**2 / a**2 + (y - center[1])**2 / b**2) <= 1
|
438 |
+
dilated_mask = np.zeros_like(binary_mask, dtype=bool)
|
439 |
+
dilated_mask[ellipse_mask] = True
|
440 |
+
else:
|
441 |
+
raise ValueError("dilation_type must be one of 'square_dilation', 'square_erosion', 'bounding_box', or 'bounding_ellipse'")
|
442 |
+
|
443 |
+
# convert the boolean mask to uint8 in [0, 255] with a trailing channel dimension
|
444 |
+
dilated_mask = np.uint8(dilated_mask[:,:,np.newaxis]) * 255
|
445 |
+
return dilated_mask
|
446 |
+
|
447 |
+
|
448 |
+
## Gradio component function
|
449 |
+
def update_vlm_model(vlm_name):
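# Swap the global VLM (LLaVA-NeXT, Qwen2-VL, or an OpenAI-compatible endpoint) selected in the dropdown, preferring a locally cached checkpoint before downloading from the Hugging Face Hub.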
|
450 |
+
global vlm_model, vlm_processor
|
451 |
+
if vlm_model is not None:
|
452 |
+
del vlm_model
|
453 |
+
torch.cuda.empty_cache()
|
454 |
+
|
455 |
+
vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[vlm_name]
|
456 |
+
|
457 |
+
## We recommend using preloaded models; otherwise downloading the model will take a long time. You can edit the code via vlm_template.py
|
458 |
+
if vlm_type == "llava-next":
|
459 |
+
if vlm_processor != "" and vlm_model != "":
|
460 |
+
vlm_model.to(device)
|
461 |
+
return vlm_model_dropdown
|
462 |
+
else:
|
463 |
+
if os.path.exists(vlm_local_path):
|
464 |
+
vlm_processor = LlavaNextProcessor.from_pretrained(vlm_local_path)
|
465 |
+
vlm_model = LlavaNextForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto")
|
466 |
+
else:
|
467 |
+
if vlm_name == "llava-v1.6-mistral-7b-hf (Preload)":
|
468 |
+
vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
|
469 |
+
vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", torch_dtype="auto", device_map="auto")
|
470 |
+
elif vlm_name == "llama3-llava-next-8b-hf (Preload)":
|
471 |
+
vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llama3-llava-next-8b-hf")
|
472 |
+
vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llama3-llava-next-8b-hf", torch_dtype="auto", device_map="auto")
|
473 |
+
elif vlm_name == "llava-v1.6-vicuna-13b-hf (Preload)":
|
474 |
+
vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf")
|
475 |
+
vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf", torch_dtype="auto", device_map="auto")
|
476 |
+
elif vlm_name == "llava-v1.6-34b-hf (Preload)":
|
477 |
+
vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-34b-hf")
|
478 |
+
vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-34b-hf", torch_dtype="auto", device_map="auto")
|
479 |
+
elif vlm_name == "llava-next-72b-hf (Preload)":
|
480 |
+
vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-next-72b-hf")
|
481 |
+
vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-next-72b-hf", torch_dtype="auto", device_map="auto")
|
482 |
+
elif vlm_type == "qwen2-vl":
|
483 |
+
if vlm_processor != "" and vlm_model != "":
|
484 |
+
vlm_model.to(device)
|
485 |
+
return vlm_model_dropdown
|
486 |
+
else:
|
487 |
+
if os.path.exists(vlm_local_path):
|
488 |
+
vlm_processor = Qwen2VLProcessor.from_pretrained(vlm_local_path)
|
489 |
+
vlm_model = Qwen2VLForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto")
|
490 |
+
else:
|
491 |
+
if vlm_name == "qwen2-vl-2b-instruct (Preload)":
|
492 |
+
vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
|
493 |
+
vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", torch_dtype="auto", device_map="auto")
|
494 |
+
elif vlm_name == "qwen2-vl-7b-instruct (Preload)":
|
495 |
+
vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
|
496 |
+
vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto")
|
497 |
+
elif vlm_name == "qwen2-vl-72b-instruct (Preload)":
|
498 |
+
vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-72B-Instruct")
|
499 |
+
vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-72B-Instruct", torch_dtype="auto", device_map="auto")
|
500 |
+
elif vlm_type == "openai":
|
501 |
+
pass
|
502 |
+
return "success"
|
503 |
+
|
504 |
+
|
505 |
+
def update_base_model(base_model_name):
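# Switch the global diffusion pipeline to the selected base model, loading a StableDiffusionBrushNetPipeline from a local path when it is not preloaded.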
|
506 |
+
global pipe
|
507 |
+
## We recommend using preloaded models; otherwise downloading the model will take a long time. You can edit the code via base_model_template.py
|
508 |
+
if pipe is not None:
|
509 |
+
del pipe
|
510 |
+
torch.cuda.empty_cache()
|
511 |
+
base_model_path, pipe = base_models_template[base_model_name]
|
512 |
+
if pipe != "":
|
513 |
+
pipe.to(device)
|
514 |
+
else:
|
515 |
+
if os.path.exists(base_model_path):
|
516 |
+
pipe = StableDiffusionBrushNetPipeline.from_pretrained(
|
517 |
+
base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
|
518 |
+
)
|
519 |
+
# pipe.enable_xformers_memory_efficient_attention()
|
520 |
+
pipe.enable_model_cpu_offload()
|
521 |
+
else:
|
522 |
+
raise gr.Error(f"The base model {base_model_name} does not exist")
|
523 |
+
return "success"
|
524 |
+
|
525 |
+
|
526 |
+
def submit_GPT4o_KEY(GPT4o_KEY):
|
527 |
+
global vlm_model, vlm_processor
|
528 |
+
if vlm_model is not None:
|
529 |
+
del vlm_model
|
530 |
+
torch.cuda.empty_cache()
|
531 |
+
try:
|
532 |
+
vlm_model = OpenAI(api_key=GPT4o_KEY, base_url="https://api.deepseek.com")
|
533 |
+
vlm_processor = ""
|
534 |
+
response = vlm_model.chat.completions.create(
|
535 |
+
model="deepseek-chat",
|
536 |
+
messages=[
|
537 |
+
{"role": "system", "content": "You are a helpful assistant."},
|
538 |
+
{"role": "user", "content": "Hello."}
|
539 |
+
]
|
540 |
+
)
|
541 |
+
response_str = response.choices[0].message.content
|
542 |
+
|
543 |
+
return "Success, " + response_str, "GPT4-o (Highly Recommended)"
|
544 |
+
except Exception as e:
|
545 |
+
return "Invalid GPT4o API Key", "GPT4-o (Highly Recommended)"
|
546 |
+
|
547 |
+
|
548 |
+
|
549 |
+
def process(input_image,
|
550 |
+
original_image,
|
551 |
+
original_mask,
|
552 |
+
prompt,
|
553 |
+
negative_prompt,
|
554 |
+
control_strength,
|
555 |
+
seed,
|
556 |
+
randomize_seed,
|
557 |
+
guidance_scale,
|
558 |
+
num_inference_steps,
|
559 |
+
num_samples,
|
560 |
+
blending,
|
561 |
+
category,
|
562 |
+
target_prompt,
|
563 |
+
resize_default,
|
564 |
+
aspect_ratio_name,
|
565 |
+
invert_mask_state):
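# Main editing entry point: resize the inputs if requested, refresh the mask, query the VLM for the editing category, target object, and target prompt when they are not supplied, then run the BrushEdit inpainting pipeline.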
|
566 |
+
if original_image is None:
|
567 |
+
if input_image is None:
|
568 |
+
raise gr.Error('Please upload the input image')
|
569 |
+
else:
|
570 |
+
print("input_image的键:", input_image.keys()) # 打印字典键
|
571 |
+
image_pil = input_image["background"].convert("RGB")
|
572 |
+
original_image = np.array(image_pil)
|
573 |
+
if prompt is None or prompt == "":
|
574 |
+
if target_prompt is None or target_prompt == "":
|
575 |
+
raise gr.Error("Please input your instructions, e.g., remove the xxx")
|
576 |
+
|
577 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
578 |
+
input_mask = np.asarray(alpha_mask)
|
579 |
+
output_w, output_h = aspect_ratios[aspect_ratio_name]
|
580 |
+
if output_w == "" or output_h == "":
|
581 |
+
output_h, output_w = original_image.shape[:2]
|
582 |
+
|
583 |
+
if resize_default:
|
584 |
+
short_side = min(output_w, output_h)
|
585 |
+
scale_ratio = 640 / short_side
|
586 |
+
output_w = int(output_w * scale_ratio)
|
587 |
+
output_h = int(output_h * scale_ratio)
|
588 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
589 |
+
original_image = np.array(original_image)
|
590 |
+
if input_mask is not None:
|
591 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
592 |
+
input_mask = np.array(input_mask)
|
593 |
+
if original_mask is not None:
|
594 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
595 |
+
original_mask = np.array(original_mask)
|
596 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
597 |
+
else:
|
598 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
599 |
+
pass
|
600 |
+
else:
|
601 |
+
if resize_default:
|
602 |
+
short_side = min(output_w, output_h)
|
603 |
+
scale_ratio = 640 / short_side
|
604 |
+
output_w = int(output_w * scale_ratio)
|
605 |
+
output_h = int(output_h * scale_ratio)
|
606 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
607 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
608 |
+
original_image = np.array(original_image)
|
609 |
+
if input_mask is not None:
|
610 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
611 |
+
input_mask = np.array(input_mask)
|
612 |
+
if original_mask is not None:
|
613 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
614 |
+
original_mask = np.array(original_mask)
|
615 |
+
|
616 |
+
if invert_mask_state:
|
617 |
+
original_mask = original_mask
|
618 |
+
else:
|
619 |
+
if input_mask.max() == 0:
|
620 |
+
original_mask = original_mask
|
621 |
+
else:
|
622 |
+
original_mask = input_mask
|
623 |
+
|
624 |
+
|
625 |
+
## inpainting directly if target_prompt is not None
|
626 |
+
if category is not None:
|
627 |
+
pass
|
628 |
+
elif target_prompt is not None and len(target_prompt) >= 1 and original_mask is not None:
|
629 |
+
pass
|
630 |
+
else:
|
631 |
+
try:
|
632 |
+
category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device)
|
633 |
+
except Exception as e:
|
634 |
+
raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
|
635 |
+
|
636 |
+
|
637 |
+
if original_mask is not None:
|
638 |
+
original_mask = np.clip(original_mask, 0, 255).astype(np.uint8)
|
639 |
+
else:
|
640 |
+
try:
|
641 |
+
object_wait_for_edit = vlm_response_object_wait_for_edit(
|
642 |
+
vlm_processor,
|
643 |
+
vlm_model,
|
644 |
+
original_image,
|
645 |
+
category,
|
646 |
+
prompt,
|
647 |
+
device)
|
648 |
+
|
649 |
+
original_mask = vlm_response_mask(vlm_processor,
|
650 |
+
vlm_model,
|
651 |
+
category,
|
652 |
+
original_image,
|
653 |
+
prompt,
|
654 |
+
object_wait_for_edit,
|
655 |
+
sam,
|
656 |
+
sam_predictor,
|
657 |
+
sam_automask_generator,
|
658 |
+
groundingdino_model,
|
659 |
+
device).astype(np.uint8)
|
660 |
+
except Exception as e:
|
661 |
+
raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
|
662 |
+
|
663 |
+
if original_mask.ndim == 2:
|
664 |
+
original_mask = original_mask[:,:,None]
|
665 |
+
|
666 |
+
|
667 |
+
if target_prompt is not None and len(target_prompt) >= 1:
|
668 |
+
prompt_after_apply_instruction = target_prompt
|
669 |
+
|
670 |
+
else:
|
671 |
+
try:
|
672 |
+
prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(
|
673 |
+
vlm_processor,
|
674 |
+
vlm_model,
|
675 |
+
original_image,
|
676 |
+
prompt,
|
677 |
+
device)
|
678 |
+
except Exception as e:
|
679 |
+
raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
|
680 |
+
|
681 |
+
generator = torch.Generator(device).manual_seed(random.randint(0, 2147483647) if randomize_seed else seed)
|
682 |
+
|
683 |
+
|
684 |
+
with torch.autocast(device):
|
685 |
+
image, mask_image, mask_np, init_image_np = BrushEdit_Pipeline(pipe,
|
686 |
+
prompt_after_apply_instruction,
|
687 |
+
original_mask,
|
688 |
+
original_image,
|
689 |
+
generator,
|
690 |
+
num_inference_steps,
|
691 |
+
guidance_scale,
|
692 |
+
control_strength,
|
693 |
+
negative_prompt,
|
694 |
+
num_samples,
|
695 |
+
blending)
|
696 |
+
original_image = np.array(init_image_np)
|
697 |
+
masked_image = original_image * (1 - (mask_np>0))
|
698 |
+
masked_image = masked_image.astype(np.uint8)
|
699 |
+
masked_image = Image.fromarray(masked_image)
|
700 |
+
# Save the images (optional)
|
701 |
+
# import uuid
|
702 |
+
# uuid = str(uuid.uuid4())
|
703 |
+
# image[0].save(f"outputs/image_edit_{uuid}_0.png")
|
704 |
+
# image[1].save(f"outputs/image_edit_{uuid}_1.png")
|
705 |
+
# image[2].save(f"outputs/image_edit_{uuid}_2.png")
|
706 |
+
# image[3].save(f"outputs/image_edit_{uuid}_3.png")
|
707 |
+
# mask_image.save(f"outputs/mask_{uuid}.png")
|
708 |
+
# masked_image.save(f"outputs/masked_image_{uuid}.png")
|
709 |
+
gr.Info(f"Target Prompt: {prompt_after_apply_instruction}", duration=20)
|
710 |
+
return image, [mask_image], [masked_image], prompt, '', False
|
711 |
+
|
712 |
+
|
713 |
+
def process_mask(input_image,
|
714 |
+
original_image,
|
715 |
+
prompt,
|
716 |
+
resize_default,
|
717 |
+
aspect_ratio_name):
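# Build the editing mask: use the user-drawn brush mask if one exists, otherwise let the VLM, GroundingDINO, and SAM segment the object referenced by the instruction.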
|
718 |
+
if original_image is None:
|
719 |
+
raise gr.Error('Please upload the input image')
|
720 |
+
if prompt is None:
|
721 |
+
raise gr.Error("Please input your instructions, e.g., remove the xxx")
|
722 |
+
|
723 |
+
## load mask
|
724 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
725 |
+
input_mask = np.array(alpha_mask)
|
726 |
+
|
727 |
+
# load example image
|
728 |
+
if isinstance(original_image, str):
|
729 |
+
original_image = input_image["background"]
|
730 |
+
|
731 |
+
if input_mask.max() == 0:
|
732 |
+
category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device)
|
733 |
+
|
734 |
+
object_wait_for_edit = vlm_response_object_wait_for_edit(vlm_processor,
|
735 |
+
vlm_model,
|
736 |
+
original_image,
|
737 |
+
category,
|
738 |
+
prompt,
|
739 |
+
device)
|
740 |
+
# original mask: h,w,1 [0, 255]
|
741 |
+
original_mask = vlm_response_mask(
|
742 |
+
vlm_processor,
|
743 |
+
vlm_model,
|
744 |
+
category,
|
745 |
+
original_image,
|
746 |
+
prompt,
|
747 |
+
object_wait_for_edit,
|
748 |
+
sam,
|
749 |
+
sam_predictor,
|
750 |
+
sam_automask_generator,
|
751 |
+
groundingdino_model,
|
752 |
+
device).astype(np.uint8)
|
753 |
+
else:
|
754 |
+
original_mask = input_mask.astype(np.uint8)
|
755 |
+
category = None
|
756 |
+
|
757 |
+
## resize mask if needed
|
758 |
+
output_w, output_h = aspect_ratios[aspect_ratio_name]
|
759 |
+
if output_w == "" or output_h == "":
|
760 |
+
output_h, output_w = original_image.shape[:2]
|
761 |
+
if resize_default:
|
762 |
+
short_side = min(output_w, output_h)
|
763 |
+
scale_ratio = 640 / short_side
|
764 |
+
output_w = int(output_w * scale_ratio)
|
765 |
+
output_h = int(output_h * scale_ratio)
|
766 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
767 |
+
original_image = np.array(original_image)
|
768 |
+
if input_mask is not None:
|
769 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
770 |
+
input_mask = np.array(input_mask)
|
771 |
+
if original_mask is not None:
|
772 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
773 |
+
original_mask = np.array(original_mask)
|
774 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
775 |
+
else:
|
776 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
777 |
+
pass
|
778 |
+
else:
|
779 |
+
if resize_default:
|
780 |
+
short_side = min(output_w, output_h)
|
781 |
+
scale_ratio = 640 / short_side
|
782 |
+
output_w = int(output_w * scale_ratio)
|
783 |
+
output_h = int(output_h * scale_ratio)
|
784 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
785 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
786 |
+
original_image = np.array(original_image)
|
787 |
+
if input_mask is not None:
|
788 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
789 |
+
input_mask = np.array(input_mask)
|
790 |
+
if original_mask is not None:
|
791 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
792 |
+
original_mask = np.array(original_mask)
|
793 |
+
|
794 |
+
|
795 |
+
if original_mask.ndim == 2:
|
796 |
+
original_mask = original_mask[:,:,None]
|
797 |
+
|
798 |
+
mask_image = Image.fromarray(original_mask.squeeze().astype(np.uint8)).convert("RGB")
|
799 |
+
|
800 |
+
masked_image = original_image * (1 - (original_mask>0))
|
801 |
+
masked_image = masked_image.astype(np.uint8)
|
802 |
+
masked_image = Image.fromarray(masked_image)
|
803 |
+
|
804 |
+
return [masked_image], [mask_image], original_mask.astype(np.uint8), category
|
805 |
+
|
806 |
+
|
807 |
+
def process_random_mask(input_image,
|
808 |
+
original_image,
|
809 |
+
original_mask,
|
810 |
+
resize_default,
|
811 |
+
aspect_ratio_name,
|
812 |
+
):
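# Replace the current mask with a randomly chosen bounding box or bounding ellipse around it.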
|
813 |
+
|
814 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
815 |
+
input_mask = np.asarray(alpha_mask)
|
816 |
+
|
817 |
+
output_w, output_h = aspect_ratios[aspect_ratio_name]
|
818 |
+
if output_w == "" or output_h == "":
|
819 |
+
output_h, output_w = original_image.shape[:2]
|
820 |
+
if resize_default:
|
821 |
+
short_side = min(output_w, output_h)
|
822 |
+
scale_ratio = 640 / short_side
|
823 |
+
output_w = int(output_w * scale_ratio)
|
824 |
+
output_h = int(output_h * scale_ratio)
|
825 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
826 |
+
original_image = np.array(original_image)
|
827 |
+
if input_mask is not None:
|
828 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
829 |
+
input_mask = np.array(input_mask)
|
830 |
+
if original_mask is not None:
|
831 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
832 |
+
original_mask = np.array(original_mask)
|
833 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
834 |
+
else:
|
835 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
836 |
+
pass
|
837 |
+
else:
|
838 |
+
if resize_default:
|
839 |
+
short_side = min(output_w, output_h)
|
840 |
+
scale_ratio = 640 / short_side
|
841 |
+
output_w = int(output_w * scale_ratio)
|
842 |
+
output_h = int(output_h * scale_ratio)
|
843 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
844 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
845 |
+
original_image = np.array(original_image)
|
846 |
+
if input_mask is not None:
|
847 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
848 |
+
input_mask = np.array(input_mask)
|
849 |
+
if original_mask is not None:
|
850 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
851 |
+
original_mask = np.array(original_mask)
|
852 |
+
|
853 |
+
|
854 |
+
if input_mask.max() == 0:
|
855 |
+
original_mask = original_mask
|
856 |
+
else:
|
857 |
+
original_mask = input_mask
|
858 |
+
|
859 |
+
if original_mask is None:
|
860 |
+
raise gr.Error('Please generate mask first')
|
861 |
+
|
862 |
+
if original_mask.ndim == 2:
|
863 |
+
original_mask = original_mask[:,:,None]
|
864 |
+
|
865 |
+
dilation_type = np.random.choice(['bounding_box', 'bounding_ellipse'])
|
866 |
+
random_mask = random_mask_func(original_mask, dilation_type).squeeze()
|
867 |
+
|
868 |
+
mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
|
869 |
+
|
870 |
+
masked_image = original_image * (1 - (random_mask[:,:,None]>0))
|
871 |
+
masked_image = masked_image.astype(original_image.dtype)
|
872 |
+
masked_image = Image.fromarray(masked_image)
|
873 |
+
|
874 |
+
|
875 |
+
return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
|
876 |
+
|
877 |
+
|
878 |
+
def process_dilation_mask(input_image,
|
879 |
+
original_image,
|
880 |
+
original_mask,
|
881 |
+
resize_default,
|
882 |
+
aspect_ratio_name,
|
883 |
+
dilation_size=20):
|
884 |
+
|
885 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
886 |
+
input_mask = np.asarray(alpha_mask)
|
887 |
+
|
888 |
+
output_w, output_h = aspect_ratios[aspect_ratio_name]
|
889 |
+
if output_w == "" or output_h == "":
|
890 |
+
output_h, output_w = original_image.shape[:2]
|
891 |
+
if resize_default:
|
892 |
+
short_side = min(output_w, output_h)
|
893 |
+
scale_ratio = 640 / short_side
|
894 |
+
output_w = int(output_w * scale_ratio)
|
895 |
+
output_h = int(output_h * scale_ratio)
|
896 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
897 |
+
original_image = np.array(original_image)
|
898 |
+
if input_mask is not None:
|
899 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
900 |
+
input_mask = np.array(input_mask)
|
901 |
+
if original_mask is not None:
|
902 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
903 |
+
original_mask = np.array(original_mask)
|
904 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
905 |
+
else:
|
906 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
907 |
+
pass
|
908 |
+
else:
|
909 |
+
if resize_default:
|
910 |
+
short_side = min(output_w, output_h)
|
911 |
+
scale_ratio = 640 / short_side
|
912 |
+
output_w = int(output_w * scale_ratio)
|
913 |
+
output_h = int(output_h * scale_ratio)
|
914 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
915 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
916 |
+
original_image = np.array(original_image)
|
917 |
+
if input_mask is not None:
|
918 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
919 |
+
input_mask = np.array(input_mask)
|
920 |
+
if original_mask is not None:
|
921 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
922 |
+
original_mask = np.array(original_mask)
|
923 |
+
|
924 |
+
if input_mask.max() == 0:
|
925 |
+
original_mask = original_mask
|
926 |
+
else:
|
927 |
+
original_mask = input_mask
|
928 |
+
|
929 |
+
if original_mask is None:
|
930 |
+
raise gr.Error('Please generate mask first')
|
931 |
+
|
932 |
+
if original_mask.ndim == 2:
|
933 |
+
original_mask = original_mask[:,:,None]
|
934 |
+
|
935 |
+
dilation_type = np.random.choice(['square_dilation'])
|
936 |
+
random_mask = random_mask_func(original_mask, dilation_type, dilation_size).squeeze()
|
937 |
+
|
938 |
+
mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
|
939 |
+
|
940 |
+
masked_image = original_image * (1 - (random_mask[:,:,None]>0))
|
941 |
+
masked_image = masked_image.astype(original_image.dtype)
|
942 |
+
masked_image = Image.fromarray(masked_image)
|
943 |
+
|
944 |
+
return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
|
945 |
+
|
946 |
+
|
947 |
+
def process_erosion_mask(input_image,
|
948 |
+
original_image,
|
949 |
+
original_mask,
|
950 |
+
resize_default,
|
951 |
+
aspect_ratio_name,
|
952 |
+
dilation_size=20):
|
953 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
954 |
+
input_mask = np.asarray(alpha_mask)
|
955 |
+
output_w, output_h = aspect_ratios[aspect_ratio_name]
|
956 |
+
if output_w == "" or output_h == "":
|
957 |
+
output_h, output_w = original_image.shape[:2]
|
958 |
+
if resize_default:
|
959 |
+
short_side = min(output_w, output_h)
|
960 |
+
scale_ratio = 640 / short_side
|
961 |
+
output_w = int(output_w * scale_ratio)
|
962 |
+
output_h = int(output_h * scale_ratio)
|
963 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
964 |
+
original_image = np.array(original_image)
|
965 |
+
if input_mask is not None:
|
966 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
967 |
+
input_mask = np.array(input_mask)
|
968 |
+
if original_mask is not None:
|
969 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
970 |
+
original_mask = np.array(original_mask)
|
971 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
972 |
+
else:
|
973 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
974 |
+
pass
|
975 |
+
else:
|
976 |
+
if resize_default:
|
977 |
+
short_side = min(output_w, output_h)
|
978 |
+
scale_ratio = 640 / short_side
|
979 |
+
output_w = int(output_w * scale_ratio)
|
980 |
+
output_h = int(output_h * scale_ratio)
|
981 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
982 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
983 |
+
original_image = np.array(original_image)
|
984 |
+
if input_mask is not None:
|
985 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
986 |
+
input_mask = np.array(input_mask)
|
987 |
+
if original_mask is not None:
|
988 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
989 |
+
original_mask = np.array(original_mask)
|
990 |
+
|
991 |
+
if input_mask.max() == 0:
|
992 |
+
original_mask = original_mask
|
993 |
+
else:
|
994 |
+
original_mask = input_mask
|
995 |
+
|
996 |
+
if original_mask is None:
|
997 |
+
raise gr.Error('Please generate mask first')
|
998 |
+
|
999 |
+
if original_mask.ndim == 2:
|
1000 |
+
original_mask = original_mask[:,:,None]
|
1001 |
+
|
1002 |
+
dilation_type = np.random.choice(['square_erosion'])
|
1003 |
+
random_mask = random_mask_func(original_mask, dilation_type, dilation_size).squeeze()
|
1004 |
+
|
1005 |
+
mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
|
1006 |
+
|
1007 |
+
masked_image = original_image * (1 - (random_mask[:,:,None]>0))
|
1008 |
+
masked_image = masked_image.astype(original_image.dtype)
|
1009 |
+
masked_image = Image.fromarray(masked_image)
|
1010 |
+
|
1011 |
+
|
1012 |
+
return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
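# The four move_mask_* callbacks below share the same flow: optionally resize the inputs, pick the active mask, shift it by `moving_pixels`, and return refreshed previews.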
|
1013 |
+
|
1014 |
+
|
1015 |
+
def move_mask_left(input_image,
|
1016 |
+
original_image,
|
1017 |
+
original_mask,
|
1018 |
+
moving_pixels,
|
1019 |
+
resize_default,
|
1020 |
+
aspect_ratio_name):
|
1021 |
+
|
1022 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
1023 |
+
input_mask = np.asarray(alpha_mask)
|
1024 |
+
|
1025 |
+
output_w, output_h = aspect_ratios[aspect_ratio_name]
|
1026 |
+
if output_w == "" or output_h == "":
|
1027 |
+
output_h, output_w = original_image.shape[:2]
|
1028 |
+
if resize_default:
|
1029 |
+
short_side = min(output_w, output_h)
|
1030 |
+
scale_ratio = 640 / short_side
|
1031 |
+
output_w = int(output_w * scale_ratio)
|
1032 |
+
output_h = int(output_h * scale_ratio)
|
1033 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
1034 |
+
original_image = np.array(original_image)
|
1035 |
+
if input_mask is not None:
|
1036 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
1037 |
+
input_mask = np.array(input_mask)
|
1038 |
+
if original_mask is not None:
|
1039 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
1040 |
+
original_mask = np.array(original_mask)
|
1041 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1042 |
+
else:
|
1043 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1044 |
+
pass
|
1045 |
+
else:
|
1046 |
+
if resize_default:
|
1047 |
+
short_side = min(output_w, output_h)
|
1048 |
+
scale_ratio = 640 / short_side
|
1049 |
+
output_w = int(output_w * scale_ratio)
|
1050 |
+
output_h = int(output_h * scale_ratio)
|
1051 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1052 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
1053 |
+
original_image = np.array(original_image)
|
1054 |
+
if input_mask is not None:
|
1055 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
1056 |
+
input_mask = np.array(input_mask)
|
1057 |
+
if original_mask is not None:
|
1058 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
1059 |
+
original_mask = np.array(original_mask)
|
1060 |
+
|
1061 |
+
if input_mask.max() == 0:
|
1062 |
+
original_mask = original_mask
|
1063 |
+
else:
|
1064 |
+
original_mask = input_mask
|
1065 |
+
|
1066 |
+
if original_mask is None:
|
1067 |
+
raise gr.Error('Please generate mask first')
|
1068 |
+
|
1069 |
+
if original_mask.ndim == 2:
|
1070 |
+
original_mask = original_mask[:,:,None]
|
1071 |
+
|
1072 |
+
moved_mask = move_mask_func(original_mask, 'left', int(moving_pixels)).squeeze()
|
1073 |
+
mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
|
1074 |
+
|
1075 |
+
masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
|
1076 |
+
masked_image = masked_image.astype(original_image.dtype)
|
1077 |
+
masked_image = Image.fromarray(masked_image)
|
1078 |
+
|
1079 |
+
if moved_mask.max() <= 1:
|
1080 |
+
moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
|
1081 |
+
original_mask = moved_mask
|
1082 |
+
return [masked_image], [mask_image], original_mask.astype(np.uint8)
|
1083 |
+
|
1084 |
+
|
1085 |
+
def move_mask_right(input_image,
|
1086 |
+
original_image,
|
1087 |
+
original_mask,
|
1088 |
+
moving_pixels,
|
1089 |
+
resize_default,
|
1090 |
+
aspect_ratio_name):
|
1091 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
1092 |
+
input_mask = np.asarray(alpha_mask)
|
1093 |
+
|
1094 |
+
output_w, output_h = aspect_ratios[aspect_ratio_name]
|
1095 |
+
if output_w == "" or output_h == "":
|
1096 |
+
output_h, output_w = original_image.shape[:2]
|
1097 |
+
if resize_default:
|
1098 |
+
short_side = min(output_w, output_h)
|
1099 |
+
scale_ratio = 640 / short_side
|
1100 |
+
output_w = int(output_w * scale_ratio)
|
1101 |
+
output_h = int(output_h * scale_ratio)
|
1102 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
1103 |
+
original_image = np.array(original_image)
|
1104 |
+
if input_mask is not None:
|
1105 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
1106 |
+
input_mask = np.array(input_mask)
|
1107 |
+
if original_mask is not None:
|
1108 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
1109 |
+
original_mask = np.array(original_mask)
|
1110 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1111 |
+
else:
|
1112 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1113 |
+
pass
|
1114 |
+
else:
|
1115 |
+
if resize_default:
|
1116 |
+
short_side = min(output_w, output_h)
|
1117 |
+
scale_ratio = 640 / short_side
|
1118 |
+
output_w = int(output_w * scale_ratio)
|
1119 |
+
output_h = int(output_h * scale_ratio)
|
1120 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1121 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
1122 |
+
original_image = np.array(original_image)
|
1123 |
+
if input_mask is not None:
|
1124 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
1125 |
+
input_mask = np.array(input_mask)
|
1126 |
+
if original_mask is not None:
|
1127 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
1128 |
+
original_mask = np.array(original_mask)
|
1129 |
+
|
1130 |
+
if input_mask.max() == 0:
|
1131 |
+
original_mask = original_mask
|
1132 |
+
else:
|
1133 |
+
original_mask = input_mask
|
1134 |
+
|
1135 |
+
if original_mask is None:
|
1136 |
+
raise gr.Error('Please generate mask first')
|
1137 |
+
|
1138 |
+
if original_mask.ndim == 2:
|
1139 |
+
original_mask = original_mask[:,:,None]
|
1140 |
+
|
1141 |
+
moved_mask = move_mask_func(original_mask, 'right', int(moving_pixels)).squeeze()
|
1142 |
+
|
1143 |
+
mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
|
1144 |
+
|
1145 |
+
masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
|
1146 |
+
masked_image = masked_image.astype(original_image.dtype)
|
1147 |
+
masked_image = Image.fromarray(masked_image)
|
1148 |
+
|
1149 |
+
|
1150 |
+
if moved_mask.max() <= 1:
|
1151 |
+
moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
|
1152 |
+
original_mask = moved_mask
|
1153 |
+
|
1154 |
+
return [masked_image], [mask_image], original_mask.astype(np.uint8)
|
1155 |
+
|
1156 |
+
|
1157 |
+
def move_mask_up(input_image,
|
1158 |
+
original_image,
|
1159 |
+
original_mask,
|
1160 |
+
moving_pixels,
|
1161 |
+
resize_default,
|
1162 |
+
aspect_ratio_name):
|
1163 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
1164 |
+
input_mask = np.asarray(alpha_mask)
|
1165 |
+
|
1166 |
+
output_w, output_h = aspect_ratios[aspect_ratio_name]
|
1167 |
+
if output_w == "" or output_h == "":
|
1168 |
+
output_h, output_w = original_image.shape[:2]
|
1169 |
+
if resize_default:
|
1170 |
+
short_side = min(output_w, output_h)
|
1171 |
+
scale_ratio = 640 / short_side
|
1172 |
+
output_w = int(output_w * scale_ratio)
|
1173 |
+
output_h = int(output_h * scale_ratio)
|
1174 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
1175 |
+
original_image = np.array(original_image)
|
1176 |
+
if input_mask is not None:
|
1177 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
1178 |
+
input_mask = np.array(input_mask)
|
1179 |
+
if original_mask is not None:
|
1180 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
1181 |
+
original_mask = np.array(original_mask)
|
1182 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1183 |
+
else:
|
1184 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1185 |
+
pass
|
1186 |
+
else:
|
1187 |
+
if resize_default:
|
1188 |
+
short_side = min(output_w, output_h)
|
1189 |
+
scale_ratio = 640 / short_side
|
1190 |
+
output_w = int(output_w * scale_ratio)
|
1191 |
+
output_h = int(output_h * scale_ratio)
|
1192 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1193 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
1194 |
+
original_image = np.array(original_image)
|
1195 |
+
if input_mask is not None:
|
1196 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
1197 |
+
input_mask = np.array(input_mask)
|
1198 |
+
if original_mask is not None:
|
1199 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
1200 |
+
original_mask = np.array(original_mask)
|
1201 |
+
|
1202 |
+
if input_mask.max() == 0:
|
1203 |
+
original_mask = original_mask
|
1204 |
+
else:
|
1205 |
+
original_mask = input_mask
|
1206 |
+
|
1207 |
+
if original_mask is None:
|
1208 |
+
raise gr.Error('Please generate mask first')
|
1209 |
+
|
1210 |
+
if original_mask.ndim == 2:
|
1211 |
+
original_mask = original_mask[:,:,None]
|
1212 |
+
|
1213 |
+
moved_mask = move_mask_func(original_mask, 'up', int(moving_pixels)).squeeze()
|
1214 |
+
mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
|
1215 |
+
|
1216 |
+
masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
|
1217 |
+
masked_image = masked_image.astype(original_image.dtype)
|
1218 |
+
masked_image = Image.fromarray(masked_image)
|
1219 |
+
|
1220 |
+
if moved_mask.max() <= 1:
|
1221 |
+
moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
|
1222 |
+
original_mask = moved_mask
|
1223 |
+
|
1224 |
+
return [masked_image], [mask_image], original_mask.astype(np.uint8)
|
1225 |
+
|
1226 |
+
|
1227 |
+
def move_mask_down(input_image,
|
1228 |
+
original_image,
|
1229 |
+
original_mask,
|
1230 |
+
moving_pixels,
|
1231 |
+
resize_default,
|
1232 |
+
aspect_ratio_name):
|
1233 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
1234 |
+
input_mask = np.asarray(alpha_mask)
|
1235 |
+
output_w, output_h = aspect_ratios[aspect_ratio_name]
|
1236 |
+
if output_w == "" or output_h == "":
|
1237 |
+
output_h, output_w = original_image.shape[:2]
|
1238 |
+
if resize_default:
|
1239 |
+
short_side = min(output_w, output_h)
|
1240 |
+
scale_ratio = 640 / short_side
|
1241 |
+
output_w = int(output_w * scale_ratio)
|
1242 |
+
output_h = int(output_h * scale_ratio)
|
1243 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
1244 |
+
original_image = np.array(original_image)
|
1245 |
+
if input_mask is not None:
|
1246 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
1247 |
+
input_mask = np.array(input_mask)
|
1248 |
+
if original_mask is not None:
|
1249 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
1250 |
+
original_mask = np.array(original_mask)
|
1251 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1252 |
+
else:
|
1253 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1254 |
+
pass
|
1255 |
+
else:
|
1256 |
+
if resize_default:
|
1257 |
+
short_side = min(output_w, output_h)
|
1258 |
+
scale_ratio = 640 / short_side
|
1259 |
+
output_w = int(output_w * scale_ratio)
|
1260 |
+
output_h = int(output_h * scale_ratio)
|
1261 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
1262 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
1263 |
+
original_image = np.array(original_image)
|
1264 |
+
if input_mask is not None:
|
1265 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
1266 |
+
input_mask = np.array(input_mask)
|
1267 |
+
if original_mask is not None:
|
1268 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
1269 |
+
original_mask = np.array(original_mask)
|
1270 |
+
|
1271 |
+
if input_mask.max() == 0:
|
1272 |
+
original_mask = original_mask
|
1273 |
+
else:
|
1274 |
+
original_mask = input_mask
|
1275 |
+
|
1276 |
+
if original_mask is None:
|
1277 |
+
raise gr.Error('Please generate mask first')
|
1278 |
+
|
1279 |
+
if original_mask.ndim == 2:
|
1280 |
+
original_mask = original_mask[:,:,None]
|
1281 |
+
|
1282 |
+
moved_mask = move_mask_func(original_mask, 'down', int(moving_pixels)).squeeze()
|
1283 |
+
mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
|
1284 |
+
|
1285 |
+
masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
|
1286 |
+
masked_image = masked_image.astype(original_image.dtype)
|
1287 |
+
masked_image = Image.fromarray(masked_image)
|
1288 |
+
|
1289 |
+
if moved_mask.max() <= 1:
|
1290 |
+
moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
|
1291 |
+
original_mask = moved_mask
|
1292 |
+
|
1293 |
+
return [masked_image], [mask_image], original_mask.astype(np.uint8)
|
1294 |
+
|
1295 |
+
|
1296 |
+
def invert_mask(input_image,
|
1297 |
+
original_image,
|
1298 |
+
original_mask,
|
1299 |
+
):
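# Invert the active mask (the user-drawn brush mask if present, otherwise the previously generated one) and return updated previews together with the invert flag.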
|
1300 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
1301 |
+
input_mask = np.asarray(alpha_mask)
|
1302 |
+
if input_mask.max() == 0:
|
1303 |
+
original_mask = 1 - (original_mask>0).astype(np.uint8)
|
1304 |
+
else:
|
1305 |
+
original_mask = 1 - (input_mask>0).astype(np.uint8)
|
1306 |
+
|
1307 |
+
if original_mask is None:
|
1308 |
+
raise gr.Error('Please generate mask first')
|
1309 |
+
|
1310 |
+
original_mask = original_mask.squeeze()
|
1311 |
+
mask_image = Image.fromarray(original_mask*255).convert("RGB")
|
1312 |
+
|
1313 |
+
if original_mask.ndim == 2:
|
1314 |
+
original_mask = original_mask[:,:,None]
|
1315 |
+
|
1316 |
+
if original_mask.max() <= 1:
|
1317 |
+
original_mask = (original_mask * 255).astype(np.uint8)
|
1318 |
+
|
1319 |
+
masked_image = original_image * (1 - (original_mask>0))
|
1320 |
+
masked_image = masked_image.astype(original_image.dtype)
|
1321 |
+
masked_image = Image.fromarray(masked_image)
|
1322 |
+
|
1323 |
+
return [masked_image], [mask_image], original_mask, True
|
1324 |
+
|
1325 |
+
|
1326 |
+
def init_img(base,
|
1327 |
+
init_type,
|
1328 |
+
prompt,
|
1329 |
+
aspect_ratio,
|
1330 |
+
example_change_times
|
1331 |
+
):
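# Initialise the UI state for a newly uploaded image; for bundled example images, load the pre-computed mask, masked-image, and result galleries instead.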
|
1332 |
+
image_pil = base["background"].convert("RGB")
|
1333 |
+
original_image = np.array(image_pil)
|
1334 |
+
if max(original_image.shape[0], original_image.shape[1]) * 1.0 / min(original_image.shape[0], original_image.shape[1])>2.0:
|
1335 |
+
raise gr.Error('image aspect ratio cannot be larger than 2.0')
|
1336 |
+
if init_type in MASK_IMAGE_PATH.keys() and example_change_times < 2:
|
1337 |
+
mask_gallery = [Image.open(MASK_IMAGE_PATH[init_type]).convert("L")]
|
1338 |
+
masked_gallery = [Image.open(MASKED_IMAGE_PATH[init_type]).convert("RGB")]
|
1339 |
+
result_gallery = [Image.open(OUTPUT_IMAGE_PATH[init_type]).convert("RGB")]
|
1340 |
+
width, height = image_pil.size
|
1341 |
+
image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
|
1342 |
+
height_new, width_new = image_processor.get_default_height_width(image_pil, height, width)
|
1343 |
+
image_pil = image_pil.resize((width_new, height_new))
|
1344 |
+
mask_gallery[0] = mask_gallery[0].resize((width_new, height_new))
|
1345 |
+
masked_gallery[0] = masked_gallery[0].resize((width_new, height_new))
|
1346 |
+
result_gallery[0] = result_gallery[0].resize((width_new, height_new))
|
1347 |
+
original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1
|
1348 |
+
return base, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, "", "", "Custom resolution", False, False, example_change_times
|
1349 |
+
else:
|
1350 |
+
if aspect_ratio not in ASPECT_RATIO_LABELS:
|
1351 |
+
aspect_ratio = "Custom resolution"
|
1352 |
+
return base, original_image, None, "", None, None, None, "", "", aspect_ratio, True, False, 0
|
1353 |
+
|
1354 |
+
|
1355 |
+
def reset_func(input_image,
|
1356 |
+
original_image,
|
1357 |
+
original_mask,
|
1358 |
+
prompt,
|
1359 |
+
target_prompt,
|
1360 |
+
):
|
1361 |
+
input_image = None
|
1362 |
+
original_image = None
|
1363 |
+
original_mask = None
|
1364 |
+
prompt = ''
|
1365 |
+
mask_gallery = []
|
1366 |
+
masked_gallery = []
|
1367 |
+
result_gallery = []
|
1368 |
+
target_prompt = ''
|
1369 |
+
if torch.cuda.is_available():
|
1370 |
+
torch.cuda.empty_cache()
|
1371 |
+
return input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, True, False
|
1372 |
+
|
1373 |
+
|
1374 |
+
def update_example(example_type,
|
1375 |
+
prompt,
|
1376 |
+
example_change_times):
|
1377 |
+
input_image = INPUT_IMAGE_PATH[example_type]
|
1378 |
+
image_pil = Image.open(input_image).convert("RGB")
|
1379 |
+
mask_gallery = [Image.open(MASK_IMAGE_PATH[example_type]).convert("L")]
|
1380 |
+
masked_gallery = [Image.open(MASKED_IMAGE_PATH[example_type]).convert("RGB")]
|
1381 |
+
result_gallery = [Image.open(OUTPUT_IMAGE_PATH[example_type]).convert("RGB")]
|
1382 |
+
width, height = image_pil.size
|
1383 |
+
image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
|
1384 |
+
height_new, width_new = image_processor.get_default_height_width(image_pil, height, width)
|
1385 |
+
image_pil = image_pil.resize((width_new, height_new))
|
1386 |
+
mask_gallery[0] = mask_gallery[0].resize((width_new, height_new))
|
1387 |
+
masked_gallery[0] = masked_gallery[0].resize((width_new, height_new))
|
1388 |
+
result_gallery[0] = result_gallery[0].resize((width_new, height_new))
|
1389 |
+
|
1390 |
+
original_image = np.array(image_pil)
|
1391 |
+
original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1
|
1392 |
+
aspect_ratio = "Custom resolution"
|
1393 |
+
example_change_times += 1
|
1394 |
+
return input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, "", False, example_change_times
|
1395 |
+
|
1396 |
+
|
1397 |
+
def generate_target_prompt(input_image,
|
1398 |
+
original_image,
|
1399 |
+
prompt):
|
1400 |
+
# load example image
|
1401 |
+
if isinstance(original_image, str):
|
1402 |
+
original_image = input_image
|
1403 |
+
|
1404 |
+
prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(
|
1405 |
+
vlm_processor,
|
1406 |
+
vlm_model,
|
1407 |
+
original_image,
|
1408 |
+
prompt,
|
1409 |
+
device)
|
1410 |
+
return prompt_after_apply_instruction
|
1411 |
+
|
1412 |
+
|
1413 |
+
# Newly added event handler functions
|
1414 |
+
def generate_blip_description(input_image):
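# Caption the uploaded image with BLIP and return the caption twice, once for the state and once for the display textbox; note that the BLIP processor and model are reloaded on every call.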
|
1415 |
+
if input_image is None:
|
1416 |
+
return "", "Input image cannot be None"
|
1417 |
+
from app.utils.utils import generate_caption
|
1418 |
+
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
|
1419 |
+
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large", torch_dtype=torch.float16).to(device)
|
1420 |
+
try:
|
1421 |
+
image_pil = input_image["background"].convert("RGB")
|
1422 |
+
except KeyError:
|
1423 |
+
return "", "Input image missing 'background' key"
|
1424 |
+
except AttributeError as e:
|
1425 |
+
return "", f"Invalid image object: {str(e)}"
|
1426 |
+
try:
|
1427 |
+
description = generate_caption(blip_processor, blip_model, image_pil, device)
|
1428 |
+
return description, description  # update both the state and the display component
|
1429 |
+
except Exception as e:
|
1430 |
+
return "", f"Caption generation failed: {str(e)}"
|
1431 |
+
|
1432 |
+
|
1433 |
+
def verify_deepseek_api():
|
1434 |
+
try:
|
1435 |
+
initialize_llm_model()
|
1436 |
+
response = llm_model.chat.completions.create(
|
1437 |
+
model="deepseek-chat",
|
1438 |
+
messages=[
|
1439 |
+
{"role": "system", "content": "You are a helpful assistant."},
|
1440 |
+
{"role": "user", "content": "Hello."}
|
1441 |
+
]
|
1442 |
+
)
|
1443 |
+
return True
|
1444 |
+
except Exception as e:
|
1445 |
+
return "Invalid GPT4o API Key", "GPT4-o (Highly Recommended)"
|
1446 |
+
|
1447 |
+
|
1448 |
+
def llm_response_prompt_after_apply_instruction(image_caption, editing_prompt):
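# Ask the DeepSeek chat model to rewrite the BLIP caption according to the editing instruction and return the enhanced target description.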
|
1449 |
+
try:
|
1450 |
+
initialize_llm_model()
|
1451 |
+
messages = create_apply_editing_messages_deepseek(image_caption, editing_prompt)
|
1452 |
+
response = llm_model.chat.completions.create(
|
1453 |
+
model="deepseek-chat",
|
1454 |
+
messages=messages
|
1455 |
+
)
|
1456 |
+
response_str = response.choices[0].message.content
|
1457 |
+
return response_str
|
1458 |
+
except Exception as e:
|
1459 |
+
raise gr.Error(f"未预期错误: {str(e)},请检查控制台日志获取详细信息")
|
1460 |
+
|
1461 |
+
|
1462 |
+
def enhance_description(prompt, blip_description):
|
1463 |
+
try:
|
1464 |
+
initialize_llm_model()
|
1465 |
+
|
1466 |
+
if not prompt or not blip_description:
|
1467 |
+
print("Empty input detected")
|
1468 |
+
return "", ""
|
1469 |
+
|
1470 |
+
print(f"Enhancing with prompt: {prompt}")
|
1471 |
+
description = llm_response_prompt_after_apply_instruction(blip_description, prompt)
|
1472 |
+
return description, description
|
1473 |
+
|
1474 |
+
except Exception as e:
|
1475 |
+
print(f"Enhancement failed: {str(e)}")
|
1476 |
+
return "Error occurred", "Error occurred"
|
1477 |
+
|
1478 |
+
|
1479 |
+
block = gr.Blocks(
|
1480 |
+
theme=gr.themes.Soft(
|
1481 |
+
radius_size=gr.themes.sizes.radius_none,
|
1482 |
+
text_size=gr.themes.sizes.text_md
|
1483 |
+
)
|
1484 |
+
)
|
1485 |
+
with block as demo:
|
1486 |
+
with gr.Row():
|
1487 |
+
with gr.Column():
|
1488 |
+
gr.HTML(head)
|
1489 |
+
|
1490 |
+
gr.Markdown(descriptions)
|
1491 |
+
|
1492 |
+
with gr.Accordion(label="🧭 Instructions:", open=True, elem_id="accordion"):
|
1493 |
+
with gr.Row(equal_height=True):
|
1494 |
+
gr.Markdown(instructions)
|
1495 |
+
|
1496 |
+
original_image = gr.State(value=None)
|
1497 |
+
original_mask = gr.State(value=None)
|
1498 |
+
category = gr.State(value=None)
|
1499 |
+
status = gr.State(value=None)
|
1500 |
+
invert_mask_state = gr.State(value=False)
|
1501 |
+
example_change_times = gr.State(value=0)
|
1502 |
+
deepseek_verified = gr.State(value=False)
|
1503 |
+
blip_description = gr.State(value="")
|
1504 |
+
deepseek_description = gr.State(value="")
|
1505 |
+
|
1506 |
+
|
1507 |
+
with gr.Row():
|
1508 |
+
with gr.Column():
|
1509 |
+
with gr.Row():
|
1510 |
+
input_image = gr.ImageEditor(
|
1511 |
+
label="Input Image",
|
1512 |
+
type="pil",
|
1513 |
+
brush=gr.Brush(colors=["#FFFFFF"], default_size = 30, color_mode="fixed"),
|
1514 |
+
layers = False,
|
1515 |
+
interactive=True,
|
1516 |
+
# height=1024,
|
1517 |
+
height=512,
|
1518 |
+
sources=["upload"],
|
1519 |
+
placeholder="Please click here or the icon below to upload the image.",
|
1520 |
+
)
|
1521 |
+
|
1522 |
+
prompt = gr.Textbox(label="⌨️ Instruction", placeholder="Please input your instruction.", value="",lines=1)
|
1523 |
+
run_button = gr.Button("💫 Run")
|
1524 |
+
|
1525 |
+
vlm_model_dropdown = gr.Dropdown(label="VLM model", choices=VLM_MODEL_NAMES, value=DEFAULT_VLM_MODEL_NAME, interactive=True)
|
1526 |
+
with gr.Group():
|
1527 |
+
with gr.Row():
|
1528 |
+
# GPT4o_KEY = gr.Textbox(label="GPT4o API Key", placeholder="Please input your GPT4o API Key when use GPT4o VLM (highly recommended).", value="", lines=1)
|
1529 |
+
# GPT4o_KEY = gr.Textbox(type="password", value="sk-d145b963a92649a88843caeb741e8bbc")
|
1530 |
+
GPT4o_KEY = gr.Textbox(label="GPT4o API Key", value="sk-d145b963a92649a88843caeb741e8bbc", lines=1)
|
1531 |
+
GPT4o_KEY_submit = gr.Button("Submit and Verify")
|
1532 |
+
|
1533 |
+
|
1534 |
+
aspect_ratio = gr.Dropdown(label="Output aspect ratio", choices=ASPECT_RATIO_LABELS, value=DEFAULT_ASPECT_RATIO)
|
1535 |
+
resize_default = gr.Checkbox(label="Short edge resize to 640px", value=True)
|
1536 |
+
|
1537 |
+
with gr.Row():
|
1538 |
+
mask_button = gr.Button("Generate Mask")
|
1539 |
+
random_mask_button = gr.Button("Square/Circle Mask ")
|
1540 |
+
|
1541 |
+
|
1542 |
+
with gr.Row():
|
1543 |
+
generate_target_prompt_button = gr.Button("Generate Target Prompt")
|
1544 |
+
|
1545 |
+
target_prompt = gr.Text(
|
1546 |
+
label="Input Target Prompt",
|
1547 |
+
max_lines=5,
|
1548 |
+
placeholder="VLM-generated target prompt, you can first generate if and then modify it (optional)",
|
1549 |
+
value='',
|
1550 |
+
lines=2
|
1551 |
+
)
|
1552 |
+
|
1553 |
+
with gr.Accordion("Advanced Options", open=False, elem_id="accordion1"):
|
1554 |
+
base_model_dropdown = gr.Dropdown(label="Base model", choices=BASE_MODELS, value=DEFAULT_BASE_MODEL, interactive=True)
|
1555 |
+
negative_prompt = gr.Text(
|
1556 |
+
label="Negative Prompt",
|
1557 |
+
max_lines=5,
|
1558 |
+
placeholder="Please input your negative prompt",
|
1559 |
+
value='ugly, low quality',lines=1
|
1560 |
+
)
|
1561 |
+
|
1562 |
+
control_strength = gr.Slider(
|
1563 |
+
label="Control Strength: ", show_label=True, minimum=0, maximum=1.1, value=1, step=0.01
|
1564 |
+
)
|
1565 |
+
with gr.Group():
|
1566 |
+
seed = gr.Slider(
|
1567 |
+
label="Seed: ", minimum=0, maximum=2147483647, step=1, value=648464818
|
1568 |
+
)
|
1569 |
+
randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
|
1570 |
+
|
1571 |
+
blending = gr.Checkbox(label="Blending mode", value=True)
|
1572 |
+
|
1573 |
+
|
1574 |
+
num_samples = gr.Slider(
|
1575 |
+
label="Num samples", minimum=0, maximum=4, step=1, value=4
|
1576 |
+
)
|
1577 |
+
|
1578 |
+
with gr.Group():
|
1579 |
+
with gr.Row():
|
1580 |
+
guidance_scale = gr.Slider(
|
1581 |
+
label="Guidance scale",
|
1582 |
+
minimum=1,
|
1583 |
+
maximum=12,
|
1584 |
+
step=0.1,
|
1585 |
+
value=7.5,
|
1586 |
+
)
|
1587 |
+
num_inference_steps = gr.Slider(
|
1588 |
+
label="Number of inference steps",
|
1589 |
+
minimum=1,
|
1590 |
+
maximum=50,
|
1591 |
+
step=1,
|
1592 |
+
value=50,
|
1593 |
+
)
|
1594 |
+
|
1595 |
+
|
1596 |
+
with gr.Column():
|
1597 |
+
with gr.Group(visible=True):
|
1598 |
+
# BLIP-generated description
|
1599 |
+
blip_output = gr.Textbox(label="BLIP-generated Description", placeholder="Automatically generated base description of the image...", interactive=False, lines=3)
|
1600 |
+
# DeepSeek API verification
|
1601 |
+
with gr.Row():
|
1602 |
+
deepseek_key = gr.Textbox(label="DeepSeek API Key", value="sk-d145b963a92649a88843caeb741e8bbc", placeholder="Enter your DeepSeek API key for enhancement", lines=1)
|
1603 |
+
verify_deepseek = gr.Button("Submit and Verify")
|
1604 |
+
# Consolidated description
|
1605 |
+
deepseek_output = gr.Textbox(label="Consolidated Description", placeholder="Enhanced description generated by DeepSeek...", interactive=True, lines=3)
|
1606 |
+
|
1607 |
+
with gr.Row():
|
1608 |
+
with gr.Tab(elem_classes="feedback", label="Masked Image"):
|
1609 |
+
masked_gallery = gr.Gallery(label='Masked Image', show_label=True, elem_id="gallery", preview=True, height=360)
|
1610 |
+
with gr.Tab(elem_classes="feedback", label="Mask"):
|
1611 |
+
mask_gallery = gr.Gallery(label='Mask', show_label=True, elem_id="gallery", preview=True, height=360)
|
1612 |
+
|
1613 |
+
invert_mask_button = gr.Button("Invert Mask")
|
1614 |
+
dilation_size = gr.Slider(
|
1615 |
+
label="Dilation size: ", minimum=0, maximum=50, step=1, value=20
|
1616 |
+
)
|
1617 |
+
with gr.Row():
|
1618 |
+
dilation_mask_button = gr.Button("Dilate Generated Mask")
|
1619 |
+
erosion_mask_button = gr.Button("Erode Generated Mask")
|
1620 |
+
|
1621 |
+
moving_pixels = gr.Slider(
|
1622 |
+
label="Moving pixels:", show_label=True, minimum=0, maximum=50, value=4, step=1
|
1623 |
+
)
|
1624 |
+
with gr.Row():
|
1625 |
+
move_left_button = gr.Button("Move Left")
|
1626 |
+
move_right_button = gr.Button("Move Right")
|
1627 |
+
with gr.Row():
|
1628 |
+
move_up_button = gr.Button("Move Up")
|
1629 |
+
move_down_button = gr.Button("Move Down")
|
1630 |
+
|
1631 |
+
with gr.Tab(elem_classes="feedback", label="Output"):
|
1632 |
+
result_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", preview=True, height=400)
|
1633 |
+
|
1634 |
+
# target_prompt_output = gr.Text(label="Output Target Prompt", value="", lines=1, interactive=False)
|
1635 |
+
|
1636 |
+
reset_button = gr.Button("Reset")
|
1637 |
+
|
1638 |
+
init_type = gr.Textbox(label="Init Name", value="", visible=False)
|
1639 |
+
example_type = gr.Textbox(label="Example Name", value="", visible=False)
|
1640 |
+
|
1641 |
+
|
1642 |
+
|
1643 |
+
with gr.Row():
|
1644 |
+
example = gr.Examples(
|
1645 |
+
label="Quick Example",
|
1646 |
+
examples=EXAMPLES,
|
1647 |
+
inputs=[input_image, prompt, seed, init_type, example_type, blending, resize_default, vlm_model_dropdown],
|
1648 |
+
examples_per_page=10,
|
1649 |
+
cache_examples=False,
|
1650 |
+
)
|
1651 |
+
|
1652 |
+
|
1653 |
+
with gr.Accordion(label="🎬 Feature Details:", open=True, elem_id="accordion"):
|
1654 |
+
with gr.Row(equal_height=True):
|
1655 |
+
gr.Markdown(tips)
|
1656 |
+
|
1657 |
+
with gr.Row():
|
1658 |
+
gr.Markdown(citation)
|
1659 |
+
|
1660 |
+
## gr.Examples cannot be used to update the gr.Gallery directly, so we need the following two functions to update the gr.Gallery.
|
1661 |
+
## We also need to resolve the conflict between the upload and example-change callbacks.
|
1662 |
+
input_image.upload(
|
1663 |
+
init_img,
|
1664 |
+
[input_image, init_type, prompt, aspect_ratio, example_change_times],
|
1665 |
+
[input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, init_type, aspect_ratio, resize_default, invert_mask_state, example_change_times]
|
1666 |
+
)
|
1667 |
+
example_type.change(fn=update_example, inputs=[example_type, prompt, example_change_times], outputs=[input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, target_prompt, invert_mask_state, example_change_times])
|
1668 |
+
|
1669 |
+
## vlm and base model dropdown
|
1670 |
+
vlm_model_dropdown.change(fn=update_vlm_model, inputs=[vlm_model_dropdown], outputs=[status])
|
1671 |
+
base_model_dropdown.change(fn=update_base_model, inputs=[base_model_dropdown], outputs=[status])
|
1672 |
+
|
1673 |
+
|
1674 |
+
GPT4o_KEY_submit.click(fn=submit_GPT4o_KEY, inputs=[GPT4o_KEY], outputs=[GPT4o_KEY, vlm_model_dropdown])
|
1675 |
+
invert_mask_button.click(fn=invert_mask, inputs=[input_image, original_image, original_mask], outputs=[masked_gallery, mask_gallery, original_mask, invert_mask_state])
|
1676 |
+
|
1677 |
+
|
1678 |
+
ips=[input_image,
|
1679 |
+
original_image,
|
1680 |
+
original_mask,
|
1681 |
+
prompt,
|
1682 |
+
negative_prompt,
|
1683 |
+
control_strength,
|
1684 |
+
seed,
|
1685 |
+
randomize_seed,
|
1686 |
+
guidance_scale,
|
1687 |
+
num_inference_steps,
|
1688 |
+
num_samples,
|
1689 |
+
blending,
|
1690 |
+
category,
|
1691 |
+
target_prompt,
|
1692 |
+
resize_default,
|
1693 |
+
aspect_ratio,
|
1694 |
+
invert_mask_state]
|
1695 |
+
|
1696 |
+
## run brushedit
|
1697 |
+
run_button.click(fn=process, inputs=ips, outputs=[result_gallery, mask_gallery, masked_gallery, prompt, target_prompt, invert_mask_state])
|
1698 |
+
|
1699 |
+
## mask func
|
1700 |
+
mask_button.click(fn=process_mask, inputs=[input_image, original_image, prompt, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask, category])
|
1701 |
+
random_mask_button.click(fn=process_random_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
|
1702 |
+
dilation_mask_button.click(fn=process_dilation_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[ masked_gallery, mask_gallery, original_mask])
|
1703 |
+
erosion_mask_button.click(fn=process_erosion_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[ masked_gallery, mask_gallery, original_mask])
|
1704 |
+
|
1705 |
+
## move mask func
|
1706 |
+
move_left_button.click(fn=move_mask_left, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
|
1707 |
+
move_right_button.click(fn=move_mask_right, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
|
1708 |
+
move_up_button.click(fn=move_mask_up, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
|
1709 |
+
move_down_button.click(fn=move_mask_down, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
|
1710 |
+
|
1711 |
+
## prompt func
|
1712 |
+
generate_target_prompt_button.click(fn=generate_target_prompt, inputs=[input_image, original_image, prompt], outputs=[target_prompt])
|
1713 |
+
|
1714 |
+
## reset func
|
1715 |
+
reset_button.click(fn=reset_func, inputs=[input_image, original_image, original_mask, prompt, target_prompt], outputs=[input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, resize_default, invert_mask_state])
|
1716 |
+
|
1717 |
+
|
1718 |
+
# Bind event handlers
|
1719 |
+
input_image.upload(fn=generate_blip_description, inputs=[input_image], outputs=[blip_description, blip_output])
|
1720 |
+
verify_deepseek.click(fn=verify_deepseek_api, outputs=[deepseek_verified]).success(fn=enhance_description, inputs=[prompt, blip_description], outputs=[deepseek_description, deepseek_output])
|
1721 |
+
# Automatically trigger enhancement when the BLIP description is updated (requires a verified API key)
|
1722 |
+
blip_description.change(fn=enhance_description, inputs=[prompt, blip_description], outputs=[deepseek_description, deepseek_output], preprocess=False)
|
1723 |
+
|
1724 |
+
# if you hit a localhost access error, try the following launch configuration
|
1725 |
+
demo.launch(server_name="0.0.0.0", server_port=12345, share=True)
|
brushedit_app_without_clip.py
ADDED
@@ -0,0 +1,1758 @@
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import os, random, sys
|
4 |
+
import numpy as np
|
5 |
+
import requests
|
6 |
+
import torch
|
7 |
+
|
8 |
+
|
9 |
+
import gradio as gr
|
10 |
+
|
11 |
+
from PIL import Image
|
12 |
+
|
13 |
+
|
14 |
+
from huggingface_hub import hf_hub_download, snapshot_download
|
15 |
+
from scipy.ndimage import binary_dilation, binary_erosion
|
16 |
+
from transformers import (LlavaNextProcessor, LlavaNextForConditionalGeneration,
|
17 |
+
Qwen2VLForConditionalGeneration, Qwen2VLProcessor)
|
18 |
+
|
19 |
+
from segment_anything import SamPredictor, build_sam, SamAutomaticMaskGenerator
|
20 |
+
from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler
|
21 |
+
from diffusers.image_processor import VaeImageProcessor
|
22 |
+
|
23 |
+
|
24 |
+
from app.src.vlm_pipeline import (
|
25 |
+
vlm_response_editing_type,
|
26 |
+
vlm_response_object_wait_for_edit,
|
27 |
+
vlm_response_mask,
|
28 |
+
vlm_response_prompt_after_apply_instruction
|
29 |
+
)
|
30 |
+
from app.src.brushedit_all_in_one_pipeline import BrushEdit_Pipeline
|
31 |
+
from app.utils.utils import load_grounding_dino_model
|
32 |
+
|
33 |
+
from app.src.vlm_template import vlms_template
|
34 |
+
from app.src.base_model_template import base_models_template
|
35 |
+
from app.src.aspect_ratio_template import aspect_ratios
|
36 |
+
|
37 |
+
from openai import OpenAI
|
38 |
+
base_openai_url = "https://api.deepseek.com/"
|
39 |
+
base_api_key = "sk-d145b963a92649a88843caeb741e8bbc"
|
40 |
+
|
41 |
+
|
42 |
+
from transformers import BlipProcessor, BlipForConditionalGeneration
|
43 |
+
|
44 |
+
from app.deepseek.instructions import (
|
45 |
+
create_apply_editing_messages_deepseek,
|
46 |
+
create_decomposed_query_messages_deepseek
|
47 |
+
)
|
48 |
+
|
49 |
+
|
50 |
+
#### Description ####
|
51 |
+
logo = r"""
|
52 |
+
<center><img src='./assets/logo_brushedit.png' alt='BrushEdit logo' style="width:80px; margin-bottom:10px"></center>
|
53 |
+
"""
|
54 |
+
head = r"""
|
55 |
+
<div style="text-align: center;">
|
56 |
+
<h1> Zero-Shot Composed Image Retrieval Based on Diffusion Model Priors and Large Language Models</h1>
|
57 |
+
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
|
58 |
+
<a href=''><img src='https://img.shields.io/badge/Project_Page-BrushEdit-green' alt='Project Page'></a>
|
59 |
+
<a href=''><img src='https://img.shields.io/badge/Paper-Arxiv-blue'></a>
|
60 |
+
<a href=''><img src='https://img.shields.io/badge/Code-Github-orange'></a>
|
61 |
+
|
62 |
+
</div>
|
63 |
+
<br/>
|
64 |
+
</div>
|
65 |
+
"""
|
66 |
+
descriptions = r"""
|
67 |
+
Demo for ZS-CIR"""
|
68 |
+
|
69 |
+
instructions = r"""
|
70 |
+
Demo for ZS-CIR"""
|
71 |
+
|
72 |
+
tips = r"""
|
73 |
+
Demo for ZS-CIR
|
74 |
+
|
75 |
+
"""
|
76 |
+
|
77 |
+
|
78 |
+
|
79 |
+
citation = r"""
|
80 |
+
Demo for ZS-CIR"""
|
81 |
+
|
82 |
+
# - - - - - examples - - - - - #
|
83 |
+
EXAMPLES = [
|
84 |
+
|
85 |
+
[
|
86 |
+
Image.open("./assets/frog/frog.jpeg").convert("RGBA"),
|
87 |
+
"add a magic hat on frog head.",
|
88 |
+
642087011,
|
89 |
+
"frog",
|
90 |
+
"frog",
|
91 |
+
True,
|
92 |
+
False,
|
93 |
+
"GPT4-o (Highly Recommended)"
|
94 |
+
],
|
95 |
+
[
|
96 |
+
Image.open("./assets/chinese_girl/chinese_girl.png").convert("RGBA"),
|
97 |
+
"replace the background to ancient China.",
|
98 |
+
648464818,
|
99 |
+
"chinese_girl",
|
100 |
+
"chinese_girl",
|
101 |
+
True,
|
102 |
+
False,
|
103 |
+
"GPT4-o (Highly Recommended)"
|
104 |
+
],
|
105 |
+
[
|
106 |
+
Image.open("./assets/angel_christmas/angel_christmas.png").convert("RGBA"),
|
107 |
+
"remove the deer.",
|
108 |
+
648464818,
|
109 |
+
"angel_christmas",
|
110 |
+
"angel_christmas",
|
111 |
+
False,
|
112 |
+
False,
|
113 |
+
"GPT4-o (Highly Recommended)"
|
114 |
+
],
|
115 |
+
[
|
116 |
+
Image.open("./assets/sunflower_girl/sunflower_girl.png").convert("RGBA"),
|
117 |
+
"add a wreath on head.",
|
118 |
+
648464818,
|
119 |
+
"sunflower_girl",
|
120 |
+
"sunflower_girl",
|
121 |
+
True,
|
122 |
+
False,
|
123 |
+
"GPT4-o (Highly Recommended)"
|
124 |
+
],
|
125 |
+
[
|
126 |
+
Image.open("./assets/girl_on_sun/girl_on_sun.png").convert("RGBA"),
|
127 |
+
"add a butterfly fairy.",
|
128 |
+
648464818,
|
129 |
+
"girl_on_sun",
|
130 |
+
"girl_on_sun",
|
131 |
+
True,
|
132 |
+
False,
|
133 |
+
"GPT4-o (Highly Recommended)"
|
134 |
+
],
|
135 |
+
[
|
136 |
+
Image.open("./assets/spider_man_rm/spider_man.png").convert("RGBA"),
|
137 |
+
"remove the christmas hat.",
|
138 |
+
642087011,
|
139 |
+
"spider_man_rm",
|
140 |
+
"spider_man_rm",
|
141 |
+
False,
|
142 |
+
False,
|
143 |
+
"GPT4-o (Highly Recommended)"
|
144 |
+
],
|
145 |
+
[
|
146 |
+
Image.open("./assets/anime_flower/anime_flower.png").convert("RGBA"),
|
147 |
+
"remove the flower.",
|
148 |
+
642087011,
|
149 |
+
"anime_flower",
|
150 |
+
"anime_flower",
|
151 |
+
False,
|
152 |
+
False,
|
153 |
+
"GPT4-o (Highly Recommended)"
|
154 |
+
],
|
155 |
+
[
|
156 |
+
Image.open("./assets/chenduling/chengduling.jpg").convert("RGBA"),
|
157 |
+
"replace the clothes to a delicated floral skirt.",
|
158 |
+
648464818,
|
159 |
+
"chenduling",
|
160 |
+
"chenduling",
|
161 |
+
True,
|
162 |
+
False,
|
163 |
+
"GPT4-o (Highly Recommended)"
|
164 |
+
],
|
165 |
+
[
|
166 |
+
Image.open("./assets/hedgehog_rp_bg/hedgehog.png").convert("RGBA"),
|
167 |
+
"make the hedgehog in Italy.",
|
168 |
+
648464818,
|
169 |
+
"hedgehog_rp_bg",
|
170 |
+
"hedgehog_rp_bg",
|
171 |
+
True,
|
172 |
+
False,
|
173 |
+
"GPT4-o (Highly Recommended)"
|
174 |
+
],
|
175 |
+
|
176 |
+
]
|
177 |
+
|
178 |
+
INPUT_IMAGE_PATH = {
|
179 |
+
"frog": "./assets/frog/frog.jpeg",
|
180 |
+
"chinese_girl": "./assets/chinese_girl/chinese_girl.png",
|
181 |
+
"angel_christmas": "./assets/angel_christmas/angel_christmas.png",
|
182 |
+
"sunflower_girl": "./assets/sunflower_girl/sunflower_girl.png",
|
183 |
+
"girl_on_sun": "./assets/girl_on_sun/girl_on_sun.png",
|
184 |
+
"spider_man_rm": "./assets/spider_man_rm/spider_man.png",
|
185 |
+
"anime_flower": "./assets/anime_flower/anime_flower.png",
|
186 |
+
"chenduling": "./assets/chenduling/chengduling.jpg",
|
187 |
+
"hedgehog_rp_bg": "./assets/hedgehog_rp_bg/hedgehog.png",
|
188 |
+
}
|
189 |
+
MASK_IMAGE_PATH = {
|
190 |
+
"frog": "./assets/frog/mask_f7b350de-6f2c-49e3-b535-995c486d78e7.png",
|
191 |
+
"chinese_girl": "./assets/chinese_girl/mask_54759648-0989-48e0-bc82-f20e28b5ec29.png",
|
192 |
+
"angel_christmas": "./assets/angel_christmas/mask_f15d9b45-c978-4e3d-9f5f-251e308560c3.png",
|
193 |
+
"sunflower_girl": "./assets/sunflower_girl/mask_99cc50b4-7dc4-4de5-8748-ec10772f0317.png",
|
194 |
+
"girl_on_sun": "./assets/girl_on_sun/mask_264eac8b-8b65-479c-9755-020a60880c37.png",
|
195 |
+
"spider_man_rm": "./assets/spider_man_rm/mask_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png",
|
196 |
+
"anime_flower": "./assets/anime_flower/mask_37553172-9b38-4727-bf2e-37d7e2b93461.png",
|
197 |
+
"chenduling": "./assets/chenduling/mask_68e3ff6f-da07-4b37-91df-13d6eed7b997.png",
|
198 |
+
"hedgehog_rp_bg": "./assets/hedgehog_rp_bg/mask_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png",
|
199 |
+
}
|
200 |
+
MASKED_IMAGE_PATH = {
|
201 |
+
"frog": "./assets/frog/masked_image_f7b350de-6f2c-49e3-b535-995c486d78e7.png",
|
202 |
+
"chinese_girl": "./assets/chinese_girl/masked_image_54759648-0989-48e0-bc82-f20e28b5ec29.png",
|
203 |
+
"angel_christmas": "./assets/angel_christmas/masked_image_f15d9b45-c978-4e3d-9f5f-251e308560c3.png",
|
204 |
+
"sunflower_girl": "./assets/sunflower_girl/masked_image_99cc50b4-7dc4-4de5-8748-ec10772f0317.png",
|
205 |
+
"girl_on_sun": "./assets/girl_on_sun/masked_image_264eac8b-8b65-479c-9755-020a60880c37.png",
|
206 |
+
"spider_man_rm": "./assets/spider_man_rm/masked_image_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png",
|
207 |
+
"anime_flower": "./assets/anime_flower/masked_image_37553172-9b38-4727-bf2e-37d7e2b93461.png",
|
208 |
+
"chenduling": "./assets/chenduling/masked_image_68e3ff6f-da07-4b37-91df-13d6eed7b997.png",
|
209 |
+
"hedgehog_rp_bg": "./assets/hedgehog_rp_bg/masked_image_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png",
|
210 |
+
}
|
211 |
+
OUTPUT_IMAGE_PATH = {
|
212 |
+
"frog": "./assets/frog/image_edit_f7b350de-6f2c-49e3-b535-995c486d78e7_1.png",
|
213 |
+
"chinese_girl": "./assets/chinese_girl/image_edit_54759648-0989-48e0-bc82-f20e28b5ec29_1.png",
|
214 |
+
"angel_christmas": "./assets/angel_christmas/image_edit_f15d9b45-c978-4e3d-9f5f-251e308560c3_0.png",
|
215 |
+
"sunflower_girl": "./assets/sunflower_girl/image_edit_99cc50b4-7dc4-4de5-8748-ec10772f0317_3.png",
|
216 |
+
"girl_on_sun": "./assets/girl_on_sun/image_edit_264eac8b-8b65-479c-9755-020a60880c37_0.png",
|
217 |
+
"spider_man_rm": "./assets/spider_man_rm/image_edit_a5d410e6-8e8d-432f-8144-defbc3e1eae9_0.png",
|
218 |
+
"anime_flower": "./assets/anime_flower/image_edit_37553172-9b38-4727-bf2e-37d7e2b93461_2.png",
|
219 |
+
"chenduling": "./assets/chenduling/image_edit_68e3ff6f-da07-4b37-91df-13d6eed7b997_0.png",
|
220 |
+
"hedgehog_rp_bg": "./assets/hedgehog_rp_bg/image_edit_db7f8bf8-8349-46d3-b14e-43d67fbe25d3_3.png",
|
221 |
+
}
|
222 |
+
|
223 |
+
# os.environ['GRADIO_TEMP_DIR'] = 'gradio_temp_dir'
|
224 |
+
# os.makedirs('gradio_temp_dir', exist_ok=True)
|
225 |
+
|
226 |
+
VLM_MODEL_NAMES = list(vlms_template.keys())
|
227 |
+
DEFAULT_VLM_MODEL_NAME = "Qwen2-VL-7B-Instruct (Default)"
|
228 |
+
|
229 |
+
|
230 |
+
BASE_MODELS = list(base_models_template.keys())
|
231 |
+
DEFAULT_BASE_MODEL = "realisticVision (Default)"
|
232 |
+
|
233 |
+
ASPECT_RATIO_LABELS = list(aspect_ratios)
|
234 |
+
DEFAULT_ASPECT_RATIO = ASPECT_RATIO_LABELS[0]
|
235 |
+
|
236 |
+
|
237 |
+
## init device
|
238 |
+
try:
|
239 |
+
if torch.cuda.is_available():
|
240 |
+
device = "cuda"
|
241 |
+
elif sys.platform == "darwin" and torch.backends.mps.is_available():
|
242 |
+
device = "mps"
|
243 |
+
else:
|
244 |
+
device = "cpu"
|
245 |
+
except:
|
246 |
+
device = "cpu"
|
247 |
+
|
248 |
+
# ## init torch dtype
|
249 |
+
# if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
|
250 |
+
# torch_dtype = torch.bfloat16
|
251 |
+
# else:
|
252 |
+
# torch_dtype = torch.float16
|
253 |
+
|
254 |
+
# if device == "mps":
|
255 |
+
# torch_dtype = torch.float16
|
256 |
+
|
257 |
+
torch_dtype = torch.float16
|
258 |
+
|
259 |
+
|
260 |
+
|
261 |
+
# download hf models
|
262 |
+
BrushEdit_path = "models/"
|
263 |
+
if not os.path.exists(BrushEdit_path):
|
264 |
+
BrushEdit_path = snapshot_download(
|
265 |
+
repo_id="TencentARC/BrushEdit",
|
266 |
+
local_dir=BrushEdit_path,
|
267 |
+
token=os.getenv("HF_TOKEN"),
|
268 |
+
)
|
269 |
+
|
270 |
+
## init default VLM
|
271 |
+
vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[DEFAULT_VLM_MODEL_NAME]
|
272 |
+
if vlm_processor != "" and vlm_model != "":
|
273 |
+
vlm_model.to(device)
|
274 |
+
else:
|
275 |
+
raise gr.Error("Please Download default VLM model "+ DEFAULT_VLM_MODEL_NAME +" first.")
|
276 |
+
|
277 |
+
## init default LLM
|
278 |
+
llm_model = OpenAI(api_key=base_api_key, base_url=base_openai_url)
|
279 |
+
|
280 |
+
## init base model
|
281 |
+
base_model_path = os.path.join(BrushEdit_path, "base_model/realisticVisionV60B1_v51VAE")
|
282 |
+
brushnet_path = os.path.join(BrushEdit_path, "brushnetX")
|
283 |
+
sam_path = os.path.join(BrushEdit_path, "sam/sam_vit_h_4b8939.pth")
|
284 |
+
groundingdino_path = os.path.join(BrushEdit_path, "grounding_dino/groundingdino_swint_ogc.pth")
|
285 |
+
|
286 |
+
|
287 |
+
# input brushnetX ckpt path
|
288 |
+
brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch_dtype)
|
289 |
+
pipe = StableDiffusionBrushNetPipeline.from_pretrained(
|
290 |
+
base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
|
291 |
+
)
|
292 |
+
# speed up diffusion process with faster scheduler and memory optimization
|
293 |
+
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
|
294 |
+
# remove following line if xformers is not installed or when using Torch 2.0.
|
295 |
+
# pipe.enable_xformers_memory_efficient_attention()
|
296 |
+
pipe.enable_model_cpu_offload()
|
297 |
+
|
298 |
+
|
299 |
+
## init SAM
|
300 |
+
sam = build_sam(checkpoint=sam_path)
|
301 |
+
sam.to(device=device)
|
302 |
+
sam_predictor = SamPredictor(sam)
|
303 |
+
sam_automask_generator = SamAutomaticMaskGenerator(sam)
|
304 |
+
|
305 |
+
## init groundingdino_model
|
306 |
+
config_file = 'app/utils/GroundingDINO_SwinT_OGC.py'
|
307 |
+
groundingdino_model = load_grounding_dino_model(config_file, groundingdino_path, device=device)
|
308 |
+
|
309 |
+
## Ordinary function
|
310 |
+
def crop_and_resize(image: Image.Image,
|
311 |
+
target_width: int,
|
312 |
+
target_height: int) -> Image.Image:
|
313 |
+
"""
|
314 |
+
Crops and resizes an image while preserving the aspect ratio.
|
315 |
+
|
316 |
+
Args:
|
317 |
+
image (Image.Image): Input PIL image to be cropped and resized.
|
318 |
+
target_width (int): Target width of the output image.
|
319 |
+
target_height (int): Target height of the output image.
|
320 |
+
|
321 |
+
Returns:
|
322 |
+
Image.Image: Cropped and resized image.
|
323 |
+
"""
|
324 |
+
# Original dimensions
|
325 |
+
original_width, original_height = image.size
|
326 |
+
original_aspect = original_width / original_height
|
327 |
+
target_aspect = target_width / target_height
|
328 |
+
|
329 |
+
# Calculate crop box to maintain aspect ratio
|
330 |
+
if original_aspect > target_aspect:
|
331 |
+
# Crop horizontally
|
332 |
+
new_width = int(original_height * target_aspect)
|
333 |
+
new_height = original_height
|
334 |
+
left = (original_width - new_width) / 2
|
335 |
+
top = 0
|
336 |
+
right = left + new_width
|
337 |
+
bottom = original_height
|
338 |
+
else:
|
339 |
+
# Crop vertically
|
340 |
+
new_width = original_width
|
341 |
+
new_height = int(original_width / target_aspect)
|
342 |
+
left = 0
|
343 |
+
top = (original_height - new_height) / 2
|
344 |
+
right = original_width
|
345 |
+
bottom = top + new_height
|
346 |
+
|
347 |
+
# Crop and resize
|
348 |
+
cropped_image = image.crop((left, top, right, bottom))
|
349 |
+
resized_image = cropped_image.resize((target_width, target_height), Image.NEAREST)
|
350 |
+
return resized_image
|
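# A minimal usage sketch for crop_and_resize (hypothetical image path and sizes, not part of the app flow):
# a wider-than-target image is center-cropped horizontally first, then resized with NEAREST interpolation.
# example_img = Image.open("./assets/frog/frog.jpeg").convert("RGB")
# square_img = crop_and_resize(example_img, target_width=640, target_height=640)  # -> 640x640 PIL image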
351 |
+
|
352 |
+
|
353 |
+
## Ordinary function
|
354 |
+
def resize(image: Image.Image,
|
355 |
+
target_width: int,
|
356 |
+
target_height: int) -> Image.Image:
|
357 |
+
"""
|
358 |
+
Resizes an image to the target dimensions without cropping (the aspect ratio is not preserved).
|
359 |
+
|
360 |
+
Args:
|
361 |
+
image (Image.Image): Input PIL image to be resized.
|
362 |
+
target_width (int): Target width of the output image.
|
363 |
+
target_height (int): Target height of the output image.
|
364 |
+
|
365 |
+
Returns:
|
366 |
+
Image.Image: Resized image.
|
367 |
+
"""
|
368 |
+
# Resize directly to the target size
|
369 |
+
resized_image = image.resize((target_width, target_height), Image.NEAREST)
|
370 |
+
return resized_image
|
371 |
+
|
372 |
+
|
373 |
+
def move_mask_func(mask, direction, units):
|
374 |
+
binary_mask = mask.squeeze()>0
|
375 |
+
rows, cols = binary_mask.shape
|
376 |
+
moved_mask = np.zeros_like(binary_mask, dtype=bool)
|
377 |
+
|
378 |
+
if direction == 'down':
|
379 |
+
# move down
|
380 |
+
moved_mask[max(0, units):, :] = binary_mask[:rows - units, :]
|
381 |
+
|
382 |
+
elif direction == 'up':
|
383 |
+
# move up
|
384 |
+
moved_mask[:rows - units, :] = binary_mask[units:, :]
|
385 |
+
|
386 |
+
elif direction == 'right':
|
387 |
+
# move right
|
388 |
+
moved_mask[:, max(0, units):] = binary_mask[:, :cols - units]
|
389 |
+
|
390 |
+
elif direction == 'left':
|
391 |
+
# move left
|
392 |
+
moved_mask[:, :cols - units] = binary_mask[:, units:]
|
393 |
+
|
394 |
+
return moved_mask
|
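# A minimal usage sketch for move_mask_func (assumed input: an (H, W) or (H, W, 1) mask with values in {0, 255}):
# content shifted past the image border is discarded and the vacated area stays False.
# shifted = move_mask_func(original_mask, 'down', 4)          # boolean (H, W) array, mask moved 4 px downward
# shifted_u8 = (shifted.astype(np.uint8) * 255)[:, :, None]   # back to the uint8 (H, W, 1) convention used below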
395 |
+
|
396 |
+
|
397 |
+
def random_mask_func(mask, dilation_type='square', dilation_size=20):
|
398 |
+
# Grow, shrink, or replace the mask according to the requested dilation type
|
399 |
+
binary_mask = mask.squeeze()>0
|
400 |
+
|
401 |
+
if dilation_type == 'square_dilation':
|
402 |
+
structure = np.ones((dilation_size, dilation_size), dtype=bool)
|
403 |
+
dilated_mask = binary_dilation(binary_mask, structure=structure)
|
404 |
+
elif dilation_type == 'square_erosion':
|
405 |
+
structure = np.ones((dilation_size, dilation_size), dtype=bool)
|
406 |
+
dilated_mask = binary_erosion(binary_mask, structure=structure)
|
407 |
+
elif dilation_type == 'bounding_box':
|
408 |
+
# find the bounding extents of the mask region
|
409 |
+
rows, cols = np.where(binary_mask)
|
410 |
+
if len(rows) == 0 or len(cols) == 0:
|
411 |
+
return mask # return original mask if no valid points
|
412 |
+
|
413 |
+
min_row = np.min(rows)
|
414 |
+
max_row = np.max(rows)
|
415 |
+
min_col = np.min(cols)
|
416 |
+
max_col = np.max(cols)
|
417 |
+
|
418 |
+
# create a bounding box
|
419 |
+
dilated_mask = np.zeros_like(binary_mask, dtype=bool)
|
420 |
+
dilated_mask[min_row:max_row + 1, min_col:max_col + 1] = True
|
421 |
+
|
422 |
+
elif dilation_type == 'bounding_ellipse':
|
423 |
+
# find the bounding extents of the mask region
|
424 |
+
rows, cols = np.where(binary_mask)
|
425 |
+
if len(rows) == 0 or len(cols) == 0:
|
426 |
+
return mask # return original mask if no valid points
|
427 |
+
|
428 |
+
min_row = np.min(rows)
|
429 |
+
max_row = np.max(rows)
|
430 |
+
min_col = np.min(cols)
|
431 |
+
max_col = np.max(cols)
|
432 |
+
|
433 |
+
# calculate the center and axis length of the ellipse
|
434 |
+
center = ((min_col + max_col) // 2, (min_row + max_row) // 2)
|
435 |
+
a = (max_col - min_col) // 2 # half long axis
|
436 |
+
b = (max_row - min_row) // 2 # half short axis
|
437 |
+
|
438 |
+
# create a bounding ellipse
|
439 |
+
y, x = np.ogrid[:mask.shape[0], :mask.shape[1]]
|
440 |
+
ellipse_mask = ((x - center[0])**2 / a**2 + (y - center[1])**2 / b**2) <= 1
|
441 |
+
dilated_mask = np.zeros_like(binary_mask, dtype=bool)
|
442 |
+
dilated_mask[ellipse_mask] = True
|
443 |
+
else:
|
444 |
+
raise ValueError("dilation_type must be one of 'square_dilation', 'square_erosion', 'bounding_box', or 'bounding_ellipse'")
|
445 |
+
|
446 |
+
# convert the boolean mask to a single-channel uint8 image in [0, 255]
|
447 |
+
dilated_mask = np.uint8(dilated_mask[:,:,np.newaxis]) * 255
|
448 |
+
return dilated_mask
|
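# A minimal usage sketch for random_mask_func (assumed input: a uint8 (H, W, 1) mask with values in {0, 255}):
# 'square_dilation' grows the region, 'square_erosion' shrinks it, and the two 'bounding_*' types replace it
# with its bounding box or bounding ellipse; the result is again a uint8 (H, W, 1) array in {0, 255}.
# grown = random_mask_func(original_mask, dilation_type='square_dilation', dilation_size=20)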
449 |
+
|
450 |
+
|
451 |
+
## Gradio component function
|
452 |
+
def update_vlm_model(vlm_name):
|
453 |
+
global vlm_model, vlm_processor
|
454 |
+
if vlm_model is not None:
|
455 |
+
del vlm_model
|
456 |
+
torch.cuda.empty_cache()
|
457 |
+
|
458 |
+
vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[vlm_name]
|
459 |
+
|
460 |
+
## We recommend using preloaded models; otherwise downloading the model can take a long time. You can edit the choices in vlm_template.py.
|
461 |
+
if vlm_type == "llava-next":
|
462 |
+
if vlm_processor != "" and vlm_model != "":
|
463 |
+
vlm_model.to(device)
|
464 |
+
return vlm_model_dropdown
|
465 |
+
else:
|
466 |
+
if os.path.exists(vlm_local_path):
|
467 |
+
vlm_processor = LlavaNextProcessor.from_pretrained(vlm_local_path)
|
468 |
+
vlm_model = LlavaNextForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto")
|
469 |
+
else:
|
470 |
+
if vlm_name == "llava-v1.6-mistral-7b-hf (Preload)":
|
471 |
+
vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
|
472 |
+
vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", torch_dtype="auto", device_map="auto")
|
473 |
+
elif vlm_name == "llama3-llava-next-8b-hf (Preload)":
|
474 |
+
vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llama3-llava-next-8b-hf")
|
475 |
+
vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llama3-llava-next-8b-hf", torch_dtype="auto", device_map="auto")
|
476 |
+
elif vlm_name == "llava-v1.6-vicuna-13b-hf (Preload)":
|
477 |
+
vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf")
|
478 |
+
vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf", torch_dtype="auto", device_map="auto")
|
479 |
+
elif vlm_name == "llava-v1.6-34b-hf (Preload)":
|
480 |
+
vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-34b-hf")
|
481 |
+
vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-34b-hf", torch_dtype="auto", device_map="auto")
|
482 |
+
elif vlm_name == "llava-next-72b-hf (Preload)":
|
483 |
+
vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-next-72b-hf")
|
484 |
+
vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-next-72b-hf", torch_dtype="auto", device_map="auto")
|
485 |
+
elif vlm_type == "qwen2-vl":
|
486 |
+
if vlm_processor != "" and vlm_model != "":
|
487 |
+
vlm_model.to(device)
|
488 |
+
return vlm_model_dropdown
|
489 |
+
else:
|
490 |
+
if os.path.exists(vlm_local_path):
|
491 |
+
vlm_processor = Qwen2VLProcessor.from_pretrained(vlm_local_path)
|
492 |
+
vlm_model = Qwen2VLForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto")
|
493 |
+
else:
|
494 |
+
if vlm_name == "qwen2-vl-2b-instruct (Preload)":
|
495 |
+
vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
|
496 |
+
vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", torch_dtype="auto", device_map="auto")
|
497 |
+
elif vlm_name == "qwen2-vl-7b-instruct (Preload)":
|
498 |
+
vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
|
499 |
+
vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto")
|
500 |
+
elif vlm_name == "qwen2-vl-72b-instruct (Preload)":
|
501 |
+
vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-72B-Instruct")
|
502 |
+
vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-72B-Instruct", torch_dtype="auto", device_map="auto")
|
503 |
+
elif vlm_type == "openai":
|
504 |
+
pass
|
505 |
+
return "success"
|
506 |
+
|
507 |
+
|
508 |
+
def update_base_model(base_model_name):
|
509 |
+
global pipe
|
510 |
+
## We recommend using preloaded models; otherwise downloading the model can take a long time. You can edit the choices in base_model_template.py.
|
511 |
+
if pipe is not None:
|
512 |
+
del pipe
|
513 |
+
torch.cuda.empty_cache()
|
514 |
+
base_model_path, pipe = base_models_template[base_model_name]
|
515 |
+
if pipe != "":
|
516 |
+
pipe.to(device)
|
517 |
+
else:
|
518 |
+
if os.path.exists(base_model_path):
|
519 |
+
pipe = StableDiffusionBrushNetPipeline.from_pretrained(
|
520 |
+
base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
|
521 |
+
)
|
522 |
+
# pipe.enable_xformers_memory_efficient_attention()
|
523 |
+
pipe.enable_model_cpu_offload()
|
524 |
+
else:
|
525 |
+
raise gr.Error(f"The base model {base_model_name} does not exist")
|
526 |
+
return "success"
|
527 |
+
|
528 |
+
|
529 |
+
def process(input_image,
|
530 |
+
original_image,
|
531 |
+
original_mask,
|
532 |
+
prompt,
|
533 |
+
negative_prompt,
|
534 |
+
control_strength,
|
535 |
+
seed,
|
536 |
+
randomize_seed,
|
537 |
+
guidance_scale,
|
538 |
+
num_inference_steps,
|
539 |
+
num_samples,
|
540 |
+
blending,
|
541 |
+
category,
|
542 |
+
target_prompt,
|
543 |
+
resize_default,
|
544 |
+
aspect_ratio_name,
|
545 |
+
invert_mask_state):
|
546 |
+
if original_image is None:
|
547 |
+
if input_image is None:
|
548 |
+
raise gr.Error('Please upload the input image')
|
549 |
+
else:
|
550 |
+
print("input_image的键:", input_image.keys()) # 打印字典键
|
551 |
+
image_pil = input_image["background"].convert("RGB")
|
552 |
+
original_image = np.array(image_pil)
|
553 |
+
if prompt is None or prompt == "":
|
554 |
+
if target_prompt is None or target_prompt == "":
|
555 |
+
raise gr.Error("Please input your instructions, e.g., remove the xxx")
|
556 |
+
|
557 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
558 |
+
input_mask = np.asarray(alpha_mask)
|
559 |
+
output_w, output_h = aspect_ratios[aspect_ratio_name]
|
560 |
+
if output_w == "" or output_h == "":
|
561 |
+
output_h, output_w = original_image.shape[:2]
|
562 |
+
|
563 |
+
if resize_default:
|
564 |
+
short_side = min(output_w, output_h)
|
565 |
+
scale_ratio = 640 / short_side
|
566 |
+
output_w = int(output_w * scale_ratio)
|
567 |
+
output_h = int(output_h * scale_ratio)
|
568 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
569 |
+
original_image = np.array(original_image)
|
570 |
+
if input_mask is not None:
|
571 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
572 |
+
input_mask = np.array(input_mask)
|
573 |
+
if original_mask is not None:
|
574 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
575 |
+
original_mask = np.array(original_mask)
|
576 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
577 |
+
else:
|
578 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
579 |
+
pass
|
580 |
+
else:
|
581 |
+
if resize_default:
|
582 |
+
short_side = min(output_w, output_h)
|
583 |
+
scale_ratio = 640 / short_side
|
584 |
+
output_w = int(output_w * scale_ratio)
|
585 |
+
output_h = int(output_h * scale_ratio)
|
586 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
587 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
588 |
+
original_image = np.array(original_image)
|
589 |
+
if input_mask is not None:
|
590 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
591 |
+
input_mask = np.array(input_mask)
|
592 |
+
if original_mask is not None:
|
593 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
594 |
+
original_mask = np.array(original_mask)
|
595 |
+
|
596 |
+
if invert_mask_state:
|
597 |
+
original_mask = original_mask
|
598 |
+
else:
|
599 |
+
if input_mask.max() == 0:
|
600 |
+
original_mask = original_mask
|
601 |
+
else:
|
602 |
+
original_mask = input_mask
|
603 |
+
|
604 |
+
|
605 |
+
## inpainting directly if target_prompt is not None
|
606 |
+
if category is not None:
|
607 |
+
pass
|
608 |
+
elif target_prompt is not None and len(target_prompt) >= 1 and original_mask is not None:
|
609 |
+
pass
|
610 |
+
else:
|
611 |
+
try:
|
612 |
+
category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device)
|
613 |
+
except Exception as e:
|
614 |
+
raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
|
615 |
+
|
616 |
+
|
617 |
+
if original_mask is not None:
|
618 |
+
original_mask = np.clip(original_mask, 0, 255).astype(np.uint8)
|
619 |
+
else:
|
620 |
+
try:
|
621 |
+
object_wait_for_edit = vlm_response_object_wait_for_edit(
|
622 |
+
vlm_processor,
|
623 |
+
vlm_model,
|
624 |
+
original_image,
|
625 |
+
category,
|
626 |
+
prompt,
|
627 |
+
device)
|
628 |
+
|
629 |
+
original_mask = vlm_response_mask(vlm_processor,
|
630 |
+
vlm_model,
|
631 |
+
category,
|
632 |
+
original_image,
|
633 |
+
prompt,
|
634 |
+
object_wait_for_edit,
|
635 |
+
sam,
|
636 |
+
sam_predictor,
|
637 |
+
sam_automask_generator,
|
638 |
+
groundingdino_model,
|
639 |
+
device).astype(np.uint8)
|
640 |
+
except Exception as e:
|
641 |
+
raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
|
642 |
+
|
643 |
+
if original_mask.ndim == 2:
|
644 |
+
original_mask = original_mask[:,:,None]
|
645 |
+
|
646 |
+
|
647 |
+
if target_prompt is not None and len(target_prompt) >= 1:
|
648 |
+
prompt_after_apply_instruction = target_prompt
|
649 |
+
|
650 |
+
else:
|
651 |
+
try:
|
652 |
+
prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(
|
653 |
+
vlm_processor,
|
654 |
+
vlm_model,
|
655 |
+
original_image,
|
656 |
+
prompt,
|
657 |
+
device)
|
658 |
+
except Exception as e:
|
659 |
+
raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
|
660 |
+
|
661 |
+
generator = torch.Generator(device).manual_seed(random.randint(0, 2147483647) if randomize_seed else seed)
|
662 |
+
|
663 |
+
|
664 |
+
with torch.autocast(device):
|
665 |
+
image, mask_image, mask_np, init_image_np = BrushEdit_Pipeline(pipe,
|
666 |
+
prompt_after_apply_instruction,
|
667 |
+
original_mask,
|
668 |
+
original_image,
|
669 |
+
generator,
|
670 |
+
num_inference_steps,
|
671 |
+
guidance_scale,
|
672 |
+
control_strength,
|
673 |
+
negative_prompt,
|
674 |
+
num_samples,
|
675 |
+
blending)
|
676 |
+
original_image = np.array(init_image_np)
|
677 |
+
masked_image = original_image * (1 - (mask_np>0))
|
678 |
+
masked_image = masked_image.astype(np.uint8)
|
679 |
+
masked_image = Image.fromarray(masked_image)
|
680 |
+
# Save the images (optional)
|
681 |
+
# import uuid
|
682 |
+
# uuid = str(uuid.uuid4())
|
683 |
+
# image[0].save(f"outputs/image_edit_{uuid}_0.png")
|
684 |
+
# image[1].save(f"outputs/image_edit_{uuid}_1.png")
|
685 |
+
# image[2].save(f"outputs/image_edit_{uuid}_2.png")
|
686 |
+
# image[3].save(f"outputs/image_edit_{uuid}_3.png")
|
687 |
+
# mask_image.save(f"outputs/mask_{uuid}.png")
|
688 |
+
# masked_image.save(f"outputs/masked_image_{uuid}.png")
|
689 |
+
gr.Info(f"Target Prompt: {prompt_after_apply_instruction}", duration=20)
|
690 |
+
return image, [mask_image], [masked_image], prompt, '', False
|
691 |
+
|
692 |
+
|
693 |
+
def process_mask(input_image,
|
694 |
+
original_image,
|
695 |
+
prompt,
|
696 |
+
resize_default,
|
697 |
+
aspect_ratio_name):
|
698 |
+
if original_image is None:
|
699 |
+
raise gr.Error('Please upload the input image')
|
700 |
+
if prompt is None:
|
701 |
+
raise gr.Error("Please input your instructions, e.g., remove the xxx")
|
702 |
+
|
703 |
+
## load mask
|
704 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
705 |
+
input_mask = np.array(alpha_mask)
|
706 |
+
|
707 |
+
# load example image
|
708 |
+
if isinstance(original_image, str):
|
709 |
+
original_image = input_image["background"]
|
710 |
+
|
711 |
+
if input_mask.max() == 0:
|
712 |
+
category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device)
|
713 |
+
|
714 |
+
object_wait_for_edit = vlm_response_object_wait_for_edit(vlm_processor,
|
715 |
+
vlm_model,
|
716 |
+
original_image,
|
717 |
+
category,
|
718 |
+
prompt,
|
719 |
+
device)
|
720 |
+
# original mask: h,w,1 [0, 255]
|
721 |
+
original_mask = vlm_response_mask(
|
722 |
+
vlm_processor,
|
723 |
+
vlm_model,
|
724 |
+
category,
|
725 |
+
original_image,
|
726 |
+
prompt,
|
727 |
+
object_wait_for_edit,
|
728 |
+
sam,
|
729 |
+
sam_predictor,
|
730 |
+
sam_automask_generator,
|
731 |
+
groundingdino_model,
|
732 |
+
device).astype(np.uint8)
|
733 |
+
else:
|
734 |
+
original_mask = input_mask.astype(np.uint8)
|
735 |
+
category = None
|
736 |
+
|
737 |
+
## resize mask if needed
|
738 |
+
output_w, output_h = aspect_ratios[aspect_ratio_name]
|
739 |
+
if output_w == "" or output_h == "":
|
740 |
+
output_h, output_w = original_image.shape[:2]
|
741 |
+
if resize_default:
|
742 |
+
short_side = min(output_w, output_h)
|
743 |
+
scale_ratio = 640 / short_side
|
744 |
+
output_w = int(output_w * scale_ratio)
|
745 |
+
output_h = int(output_h * scale_ratio)
|
746 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
747 |
+
original_image = np.array(original_image)
|
748 |
+
if input_mask is not None:
|
749 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
750 |
+
input_mask = np.array(input_mask)
|
751 |
+
if original_mask is not None:
|
752 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
753 |
+
original_mask = np.array(original_mask)
|
754 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
755 |
+
else:
|
756 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
757 |
+
pass
|
758 |
+
else:
|
759 |
+
if resize_default:
|
760 |
+
short_side = min(output_w, output_h)
|
761 |
+
scale_ratio = 640 / short_side
|
762 |
+
output_w = int(output_w * scale_ratio)
|
763 |
+
output_h = int(output_h * scale_ratio)
|
764 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
765 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
766 |
+
original_image = np.array(original_image)
|
767 |
+
if input_mask is not None:
|
768 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
769 |
+
input_mask = np.array(input_mask)
|
770 |
+
if original_mask is not None:
|
771 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
772 |
+
original_mask = np.array(original_mask)
|
773 |
+
|
774 |
+
|
775 |
+
if original_mask.ndim == 2:
|
776 |
+
original_mask = original_mask[:,:,None]
|
777 |
+
|
778 |
+
mask_image = Image.fromarray(original_mask.squeeze().astype(np.uint8)).convert("RGB")
|
779 |
+
|
780 |
+
masked_image = original_image * (1 - (original_mask>0))
|
781 |
+
masked_image = masked_image.astype(np.uint8)
|
782 |
+
masked_image = Image.fromarray(masked_image)
|
783 |
+
|
784 |
+
return [masked_image], [mask_image], original_mask.astype(np.uint8), category
|
785 |
+
|
786 |
+
|
787 |
+
def process_random_mask(input_image,
|
788 |
+
original_image,
|
789 |
+
original_mask,
|
790 |
+
resize_default,
|
791 |
+
aspect_ratio_name,
|
792 |
+
):
|
793 |
+
|
794 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
795 |
+
input_mask = np.asarray(alpha_mask)
|
796 |
+
|
797 |
+
output_w, output_h = aspect_ratios[aspect_ratio_name]
|
798 |
+
if output_w == "" or output_h == "":
|
799 |
+
output_h, output_w = original_image.shape[:2]
|
800 |
+
if resize_default:
|
801 |
+
short_side = min(output_w, output_h)
|
802 |
+
scale_ratio = 640 / short_side
|
803 |
+
output_w = int(output_w * scale_ratio)
|
804 |
+
output_h = int(output_h * scale_ratio)
|
805 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
806 |
+
original_image = np.array(original_image)
|
807 |
+
if input_mask is not None:
|
808 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
809 |
+
input_mask = np.array(input_mask)
|
810 |
+
if original_mask is not None:
|
811 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
812 |
+
original_mask = np.array(original_mask)
|
813 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
814 |
+
else:
|
815 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
816 |
+
pass
|
817 |
+
else:
|
818 |
+
if resize_default:
|
819 |
+
short_side = min(output_w, output_h)
|
820 |
+
scale_ratio = 640 / short_side
|
821 |
+
output_w = int(output_w * scale_ratio)
|
822 |
+
output_h = int(output_h * scale_ratio)
|
823 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
824 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
825 |
+
original_image = np.array(original_image)
|
826 |
+
if input_mask is not None:
|
827 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
828 |
+
input_mask = np.array(input_mask)
|
829 |
+
if original_mask is not None:
|
830 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
831 |
+
original_mask = np.array(original_mask)
|
832 |
+
|
833 |
+
|
834 |
+
if input_mask.max() == 0:
|
835 |
+
original_mask = original_mask
|
836 |
+
else:
|
837 |
+
original_mask = input_mask
|
838 |
+
|
839 |
+
if original_mask is None:
|
840 |
+
raise gr.Error('Please generate mask first')
|
841 |
+
|
842 |
+
if original_mask.ndim == 2:
|
843 |
+
original_mask = original_mask[:,:,None]
|
844 |
+
|
845 |
+
dilation_type = np.random.choice(['bounding_box', 'bounding_ellipse'])
|
846 |
+
random_mask = random_mask_func(original_mask, dilation_type).squeeze()
|
847 |
+
|
848 |
+
mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
|
849 |
+
|
850 |
+
masked_image = original_image * (1 - (random_mask[:,:,None]>0))
|
851 |
+
masked_image = masked_image.astype(original_image.dtype)
|
852 |
+
masked_image = Image.fromarray(masked_image)
|
853 |
+
|
854 |
+
|
855 |
+
return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
|
856 |
+
|
857 |
+
|
858 |
+
def process_dilation_mask(input_image,
|
859 |
+
original_image,
|
860 |
+
original_mask,
|
861 |
+
resize_default,
|
862 |
+
aspect_ratio_name,
|
863 |
+
dilation_size=20):
|
864 |
+
|
865 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
866 |
+
input_mask = np.asarray(alpha_mask)
|
867 |
+
|
868 |
+
output_w, output_h = aspect_ratios[aspect_ratio_name]
|
869 |
+
if output_w == "" or output_h == "":
|
870 |
+
output_h, output_w = original_image.shape[:2]
|
871 |
+
if resize_default:
|
872 |
+
short_side = min(output_w, output_h)
|
873 |
+
scale_ratio = 640 / short_side
|
874 |
+
output_w = int(output_w * scale_ratio)
|
875 |
+
output_h = int(output_h * scale_ratio)
|
876 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
877 |
+
original_image = np.array(original_image)
|
878 |
+
if input_mask is not None:
|
879 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
880 |
+
input_mask = np.array(input_mask)
|
881 |
+
if original_mask is not None:
|
882 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
883 |
+
original_mask = np.array(original_mask)
|
884 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
885 |
+
else:
|
886 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
887 |
+
pass
|
888 |
+
else:
|
889 |
+
if resize_default:
|
890 |
+
short_side = min(output_w, output_h)
|
891 |
+
scale_ratio = 640 / short_side
|
892 |
+
output_w = int(output_w * scale_ratio)
|
893 |
+
output_h = int(output_h * scale_ratio)
|
894 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
895 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
896 |
+
original_image = np.array(original_image)
|
897 |
+
if input_mask is not None:
|
898 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
899 |
+
input_mask = np.array(input_mask)
|
900 |
+
if original_mask is not None:
|
901 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
902 |
+
original_mask = np.array(original_mask)
|
903 |
+
|
904 |
+
if input_mask.max() == 0:
|
905 |
+
original_mask = original_mask
|
906 |
+
else:
|
907 |
+
original_mask = input_mask
|
908 |
+
|
909 |
+
if original_mask is None:
|
910 |
+
raise gr.Error('Please generate mask first')
|
911 |
+
|
912 |
+
if original_mask.ndim == 2:
|
913 |
+
original_mask = original_mask[:,:,None]
|
914 |
+
|
915 |
+
dilation_type = np.random.choice(['square_dilation'])
|
916 |
+
random_mask = random_mask_func(original_mask, dilation_type, dilation_size).squeeze()
|
917 |
+
|
918 |
+
mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
|
919 |
+
|
920 |
+
masked_image = original_image * (1 - (random_mask[:,:,None]>0))
|
921 |
+
masked_image = masked_image.astype(original_image.dtype)
|
922 |
+
masked_image = Image.fromarray(masked_image)
|
923 |
+
|
924 |
+
return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
|
925 |
+
|
926 |
+
|
927 |
+
def process_erosion_mask(input_image,
|
928 |
+
original_image,
|
929 |
+
original_mask,
|
930 |
+
resize_default,
|
931 |
+
aspect_ratio_name,
|
932 |
+
dilation_size=20):
|
933 |
+
alpha_mask = input_image["layers"][0].split()[3]
|
934 |
+
input_mask = np.asarray(alpha_mask)
|
935 |
+
output_w, output_h = aspect_ratios[aspect_ratio_name]
|
936 |
+
if output_w == "" or output_h == "":
|
937 |
+
output_h, output_w = original_image.shape[:2]
|
938 |
+
if resize_default:
|
939 |
+
short_side = min(output_w, output_h)
|
940 |
+
scale_ratio = 640 / short_side
|
941 |
+
output_w = int(output_w * scale_ratio)
|
942 |
+
output_h = int(output_h * scale_ratio)
|
943 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
944 |
+
original_image = np.array(original_image)
|
945 |
+
if input_mask is not None:
|
946 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
947 |
+
input_mask = np.array(input_mask)
|
948 |
+
if original_mask is not None:
|
949 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
950 |
+
original_mask = np.array(original_mask)
|
951 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
952 |
+
else:
|
953 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
954 |
+
pass
|
955 |
+
else:
|
956 |
+
if resize_default:
|
957 |
+
short_side = min(output_w, output_h)
|
958 |
+
scale_ratio = 640 / short_side
|
959 |
+
output_w = int(output_w * scale_ratio)
|
960 |
+
output_h = int(output_h * scale_ratio)
|
961 |
+
gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
|
962 |
+
original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
|
963 |
+
original_image = np.array(original_image)
|
964 |
+
if input_mask is not None:
|
965 |
+
input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
|
966 |
+
input_mask = np.array(input_mask)
|
967 |
+
if original_mask is not None:
|
968 |
+
original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
|
969 |
+
original_mask = np.array(original_mask)
|
970 |
+
|
971 |
+
if input_mask.max() == 0:
|
972 |
+
original_mask = original_mask
|
973 |
+
else:
|
974 |
+
original_mask = input_mask
|
975 |
+
|
976 |
+
if original_mask is None:
|
977 |
+
raise gr.Error('Please generate mask first')
|
978 |
+
|
979 |
+
if original_mask.ndim == 2:
|
980 |
+
original_mask = original_mask[:,:,None]
|
981 |
+
|
982 |
+
dilation_type = np.random.choice(['square_erosion'])
|
983 |
+
random_mask = random_mask_func(original_mask, dilation_type, dilation_size).squeeze()
|
984 |
+
|
985 |
+
mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
|
986 |
+
|
987 |
+
masked_image = original_image * (1 - (random_mask[:,:,None]>0))
|
988 |
+
masked_image = masked_image.astype(original_image.dtype)
|
989 |
+
masked_image = Image.fromarray(masked_image)
|
990 |
+
|
991 |
+
|
992 |
+
return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
|
993 |
+
|
994 |
+
|
995 |
+
def move_mask_left(input_image,
                   original_image,
                   original_mask,
                   moving_pixels,
                   resize_default,
                   aspect_ratio_name):

    alpha_mask = input_image["layers"][0].split()[3]
    input_mask = np.asarray(alpha_mask)

    output_w, output_h = aspect_ratios[aspect_ratio_name]
    if output_w == "" or output_h == "":
        output_h, output_w = original_image.shape[:2]
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
            original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
            original_image = np.array(original_image)
            if input_mask is not None:
                input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
                input_mask = np.array(input_mask)
            if original_mask is not None:
                original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
                original_mask = np.array(original_mask)
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        else:
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
            pass
    else:
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
        gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
        original_image = np.array(original_image)
        if input_mask is not None:
            input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
            input_mask = np.array(input_mask)
        if original_mask is not None:
            original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
            original_mask = np.array(original_mask)

    if input_mask.max() == 0:
        original_mask = original_mask
    else:
        original_mask = input_mask

    if original_mask is None:
        raise gr.Error('Please generate mask first')

    if original_mask.ndim == 2:
        original_mask = original_mask[:,:,None]

    moved_mask = move_mask_func(original_mask, 'left', int(moving_pixels)).squeeze()
    mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")

    masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
    masked_image = masked_image.astype(original_image.dtype)
    masked_image = Image.fromarray(masked_image)

    if moved_mask.max() <= 1:
        moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
        original_mask = moved_mask
    return [masked_image], [mask_image], original_mask.astype(np.uint8)


def move_mask_right(input_image,
                    original_image,
                    original_mask,
                    moving_pixels,
                    resize_default,
                    aspect_ratio_name):
    alpha_mask = input_image["layers"][0].split()[3]
    input_mask = np.asarray(alpha_mask)

    output_w, output_h = aspect_ratios[aspect_ratio_name]
    if output_w == "" or output_h == "":
        output_h, output_w = original_image.shape[:2]
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
            original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
            original_image = np.array(original_image)
            if input_mask is not None:
                input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
                input_mask = np.array(input_mask)
            if original_mask is not None:
                original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
                original_mask = np.array(original_mask)
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        else:
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
            pass
    else:
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
        gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
        original_image = np.array(original_image)
        if input_mask is not None:
            input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
            input_mask = np.array(input_mask)
        if original_mask is not None:
            original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
            original_mask = np.array(original_mask)

    if input_mask.max() == 0:
        original_mask = original_mask
    else:
        original_mask = input_mask

    if original_mask is None:
        raise gr.Error('Please generate mask first')

    if original_mask.ndim == 2:
        original_mask = original_mask[:,:,None]

    moved_mask = move_mask_func(original_mask, 'right', int(moving_pixels)).squeeze()

    mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")

    masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
    masked_image = masked_image.astype(original_image.dtype)
    masked_image = Image.fromarray(masked_image)


    if moved_mask.max() <= 1:
        moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
        original_mask = moved_mask

    return [masked_image], [mask_image], original_mask.astype(np.uint8)


def move_mask_up(input_image,
                 original_image,
                 original_mask,
                 moving_pixels,
                 resize_default,
                 aspect_ratio_name):
    alpha_mask = input_image["layers"][0].split()[3]
    input_mask = np.asarray(alpha_mask)

    output_w, output_h = aspect_ratios[aspect_ratio_name]
    if output_w == "" or output_h == "":
        output_h, output_w = original_image.shape[:2]
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
            original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
            original_image = np.array(original_image)
            if input_mask is not None:
                input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
                input_mask = np.array(input_mask)
            if original_mask is not None:
                original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
                original_mask = np.array(original_mask)
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        else:
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
            pass
    else:
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
        gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
        original_image = np.array(original_image)
        if input_mask is not None:
            input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
            input_mask = np.array(input_mask)
        if original_mask is not None:
            original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
            original_mask = np.array(original_mask)

    if input_mask.max() == 0:
        original_mask = original_mask
    else:
        original_mask = input_mask

    if original_mask is None:
        raise gr.Error('Please generate mask first')

    if original_mask.ndim == 2:
        original_mask = original_mask[:,:,None]

    moved_mask = move_mask_func(original_mask, 'up', int(moving_pixels)).squeeze()
    mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")

    masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
    masked_image = masked_image.astype(original_image.dtype)
    masked_image = Image.fromarray(masked_image)

    if moved_mask.max() <= 1:
        moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
        original_mask = moved_mask

    return [masked_image], [mask_image], original_mask.astype(np.uint8)


def move_mask_down(input_image,
                   original_image,
                   original_mask,
                   moving_pixels,
                   resize_default,
                   aspect_ratio_name):
    alpha_mask = input_image["layers"][0].split()[3]
    input_mask = np.asarray(alpha_mask)
    output_w, output_h = aspect_ratios[aspect_ratio_name]
    if output_w == "" or output_h == "":
        output_h, output_w = original_image.shape[:2]
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
            original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
            original_image = np.array(original_image)
            if input_mask is not None:
                input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
                input_mask = np.array(input_mask)
            if original_mask is not None:
                original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
                original_mask = np.array(original_mask)
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        else:
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
            pass
    else:
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
        gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
        original_image = np.array(original_image)
        if input_mask is not None:
            input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
            input_mask = np.array(input_mask)
        if original_mask is not None:
            original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
            original_mask = np.array(original_mask)

    if input_mask.max() == 0:
        original_mask = original_mask
    else:
        original_mask = input_mask

    if original_mask is None:
        raise gr.Error('Please generate mask first')

    if original_mask.ndim == 2:
        original_mask = original_mask[:,:,None]

    moved_mask = move_mask_func(original_mask, 'down', int(moving_pixels)).squeeze()
    mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")

    masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
    masked_image = masked_image.astype(original_image.dtype)
    masked_image = Image.fromarray(masked_image)

    if moved_mask.max() <= 1:
        moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
        original_mask = moved_mask

    return [masked_image], [mask_image], original_mask.astype(np.uint8)

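All four handlers above call move_mask_func(mask, direction, pixels), which is defined elsewhere in this file. A minimal sketch of what such a shift could look like, assuming a NumPy mask of shape (H, W, 1) and zero-filling the vacated border (illustrative names, not the app's actual implementation):

import numpy as np

def move_mask_sketch(mask: np.ndarray, direction: str, pixels: int) -> np.ndarray:
    """Shift a binary mask by `pixels` and blank the region that wraps around."""
    shifts = {"left": (0, -pixels), "right": (0, pixels), "up": (-pixels, 0), "down": (pixels, 0)}
    dy, dx = shifts[direction]
    moved = np.roll(mask, shift=(dy, dx), axis=(0, 1))
    # Zero out the wrapped-around strip so the mask does not reappear on the opposite side.
    if dy > 0:
        moved[:dy] = 0
    elif dy < 0:
        moved[dy:] = 0
    if dx > 0:
        moved[:, :dx] = 0
    elif dx < 0:
        moved[:, dx:] = 0
    return moved
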
def invert_mask(input_image,
                original_image,
                original_mask,
                ):
    alpha_mask = input_image["layers"][0].split()[3]
    input_mask = np.asarray(alpha_mask)
    if input_mask.max() == 0:
        original_mask = 1 - (original_mask>0).astype(np.uint8)
    else:
        original_mask = 1 - (input_mask>0).astype(np.uint8)

    if original_mask is None:
        raise gr.Error('Please generate mask first')

    original_mask = original_mask.squeeze()
    mask_image = Image.fromarray(original_mask*255).convert("RGB")

    if original_mask.ndim == 2:
        original_mask = original_mask[:,:,None]

    if original_mask.max() <= 1:
        original_mask = (original_mask * 255).astype(np.uint8)

    masked_image = original_image * (1 - (original_mask>0))
    masked_image = masked_image.astype(original_image.dtype)
    masked_image = Image.fromarray(masked_image)

    return [masked_image], [mask_image], original_mask, True


def init_img(base,
             init_type,
             prompt,
             aspect_ratio,
             example_change_times
             ):
    image_pil = base["background"].convert("RGB")
    original_image = np.array(image_pil)
    if max(original_image.shape[0], original_image.shape[1]) * 1.0 / min(original_image.shape[0], original_image.shape[1])>2.0:
        raise gr.Error('image aspect ratio cannot be larger than 2.0')
    if init_type in MASK_IMAGE_PATH.keys() and example_change_times < 2:
        mask_gallery = [Image.open(MASK_IMAGE_PATH[init_type]).convert("L")]
        masked_gallery = [Image.open(MASKED_IMAGE_PATH[init_type]).convert("RGB")]
        result_gallery = [Image.open(OUTPUT_IMAGE_PATH[init_type]).convert("RGB")]
        width, height = image_pil.size
        image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
        height_new, width_new = image_processor.get_default_height_width(image_pil, height, width)
        image_pil = image_pil.resize((width_new, height_new))
        mask_gallery[0] = mask_gallery[0].resize((width_new, height_new))
        masked_gallery[0] = masked_gallery[0].resize((width_new, height_new))
        result_gallery[0] = result_gallery[0].resize((width_new, height_new))
        original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1
        return base, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, "", "", "Custom resolution", False, False, example_change_times
    else:
        if aspect_ratio not in ASPECT_RATIO_LABELS:
            aspect_ratio = "Custom resolution"
        return base, original_image, None, "", None, None, None, "", "", aspect_ratio, True, False, 0


def reset_func(input_image,
               original_image,
               original_mask,
               prompt,
               target_prompt,
               ):
    input_image = None
    original_image = None
    original_mask = None
    prompt = ''
    mask_gallery = []
    masked_gallery = []
    result_gallery = []
    target_prompt = ''
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, True, False


def update_example(example_type,
                   prompt,
                   example_change_times):
    input_image = INPUT_IMAGE_PATH[example_type]
    image_pil = Image.open(input_image).convert("RGB")
    mask_gallery = [Image.open(MASK_IMAGE_PATH[example_type]).convert("L")]
    masked_gallery = [Image.open(MASKED_IMAGE_PATH[example_type]).convert("RGB")]
    result_gallery = [Image.open(OUTPUT_IMAGE_PATH[example_type]).convert("RGB")]
    width, height = image_pil.size
    image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
    height_new, width_new = image_processor.get_default_height_width(image_pil, height, width)
    image_pil = image_pil.resize((width_new, height_new))
    mask_gallery[0] = mask_gallery[0].resize((width_new, height_new))
    masked_gallery[0] = masked_gallery[0].resize((width_new, height_new))
    result_gallery[0] = result_gallery[0].resize((width_new, height_new))

    original_image = np.array(image_pil)
    original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1
    aspect_ratio = "Custom resolution"
    example_change_times += 1
    return input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, "", False, example_change_times

def generate_target_prompt(input_image,
                           original_image,
                           prompt):
    # load example image
    if isinstance(original_image, str):
        original_image = input_image

    prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(
                                                vlm_processor,
                                                vlm_model,
                                                original_image,
                                                prompt,
                                                device)
    return prompt_after_apply_instruction

# Newly added event handlers
def generate_blip_description(input_image):
    if input_image is None:
        return "", "Input image cannot be None"
    from app.utils.utils import generate_caption
    blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
    blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large", torch_dtype=torch.float16).to(device)
    try:
        image_pil = input_image["background"].convert("RGB")
    except KeyError:
        return "", "Input image missing 'background' key"
    except AttributeError as e:
        return "", f"Invalid image object: {str(e)}"
    try:
        description = generate_caption(blip_processor, blip_model, image_pil, device)
        return description, description  # update both the state and the display component
    except Exception as e:
        return "", f"Caption generation failed: {str(e)}"

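generate_caption itself lives in app/utils/utils.py and is not shown in this diff. For reference, a minimal BLIP captioning call with the same processor/model objects would typically look like the sketch below (standard transformers usage, not the app's actual helper):

import torch
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration

def caption_image_sketch(image: Image.Image, device: str = "cuda") -> str:
    processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
    model = BlipForConditionalGeneration.from_pretrained(
        "Salesforce/blip-image-captioning-large", torch_dtype=torch.float16
    ).to(device)
    inputs = processor(image, return_tensors="pt").to(device, torch.float16)
    out = model.generate(**inputs, max_new_tokens=50)
    return processor.decode(out[0], skip_special_tokens=True)
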
def submit_GPT4o_KEY(GPT4o_KEY):
    global vlm_model, vlm_processor
    if vlm_model is not None:
        del vlm_model
        torch.cuda.empty_cache()
    try:
        vlm_model = OpenAI(api_key=GPT4o_KEY, base_url="https://api.deepseek.com")
        vlm_processor = ""
        response = vlm_model.chat.completions.create(
            model="deepseek-chat",
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Hello."}
            ]
        )
        response_str = response.choices[0].message.content

        return "Success. " + response_str, "GPT4-o (Highly Recommended)"
    except Exception as e:
        return "Invalid GPT4o API Key", "GPT4-o (Highly Recommended)"


def verify_deepseek_api():
    try:
        response = llm_model.chat.completions.create(
            model="deepseek-chat",
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Hello."}
            ]
        )
        response_str = response.choices[0].message.content

        return True, "Success. " + response_str

    except Exception as e:
        return False, "Invalid DeepSeek API Key"


def llm_enhanced_prompt_after_apply_instruction(image_caption, editing_prompt):
    try:
        messages = create_apply_editing_messages_deepseek(image_caption, editing_prompt)
        response = llm_model.chat.completions.create(
            model="deepseek-chat",
            messages=messages
        )
        response_str = response.choices[0].message.content
        return response_str
    except Exception as e:
        raise gr.Error(f"整合指令时遇到错误: {str(e)},请检查控制台日志获取详细信息")


def llm_decomposed_prompt_after_apply_instruction(integrated_query):
    try:
        messages = create_decomposed_query_messages_deepseek(integrated_query)
        response = llm_model.chat.completions.create(
            model="deepseek-chat",
            messages=messages
        )
        response_str = response.choices[0].message.content
        return response_str
    except Exception as e:
        raise gr.Error(f"分解指令时遇到错误: {str(e)},请检查控制台日志获取详细信息")


def enhance_description(blip_description, prompt):
    try:
        if not prompt or not blip_description:
            print("Empty prompt or blip_description detected")
            return "", ""

        print(f"Enhancing with prompt: {prompt}")
        enhanced_description = llm_enhanced_prompt_after_apply_instruction(blip_description, prompt)
        return enhanced_description, enhanced_description

    except Exception as e:
        print(f"Enhancement failed: {str(e)}")
        return "Error occurred", "Error occurred"

def decompose_description(enhanced_description):
    try:
        if not enhanced_description:
            print("Empty enhanced_description detected")
            return "", ""

        print(f"Decomposing the enhanced description: {enhanced_description}")
        decomposed_description = llm_decomposed_prompt_after_apply_instruction(enhanced_description)
        return decomposed_description, decomposed_description

    except Exception as e:
        print(f"Decomposition failed: {str(e)}")
        return "Error occurred", "Error occurred"

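The message builders used above (create_apply_editing_messages_deepseek, create_decomposed_query_messages_deepseek) are defined in app/deepseek/instructions.py and are not part of this diff. The DeepSeek endpoint is OpenAI-compatible, so the call they feed is an ordinary chat completion; a hedged sketch with hypothetical prompt content:

from openai import OpenAI

# Hypothetical example; the real system/user prompts come from app/deepseek/instructions.py.
client = OpenAI(api_key="YOUR_DEEPSEEK_KEY", base_url="https://api.deepseek.com")
messages = [
    {"role": "system", "content": "Rewrite the image caption so that the requested edit is applied."},
    {"role": "user", "content": "Caption: a dog on a sofa. Edit: replace the dog with a cat."},
]
response = client.chat.completions.create(model="deepseek-chat", messages=messages)
print(response.choices[0].message.content)
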
block = gr.Blocks(
        theme=gr.themes.Soft(
             radius_size=gr.themes.sizes.radius_none,
             text_size=gr.themes.sizes.text_md
         )
        )
with block as demo:
    with gr.Row():
        with gr.Column():
            gr.HTML(head)

    gr.Markdown(descriptions)

    with gr.Accordion(label="🧭 Instructions:", open=True, elem_id="accordion"):
        with gr.Row(equal_height=True):
            gr.Markdown(instructions)

    original_image = gr.State(value=None)
    original_mask = gr.State(value=None)
    category = gr.State(value=None)
    status = gr.State(value=None)
    invert_mask_state = gr.State(value=False)
    example_change_times = gr.State(value=0)
    deepseek_verified = gr.State(value=False)
    blip_description = gr.State(value="")
    enhanced_description = gr.State(value="")
    decomposed_description = gr.State(value="")

    with gr.Row():
        with gr.Column():
            with gr.Row():
                input_image = gr.ImageEditor(
                    label="参考图像",
                    type="pil",
                    brush=gr.Brush(colors=["#FFFFFF"], default_size = 30, color_mode="fixed"),
                    layers = False,
                    interactive=True,
                    # height=1024,
                    height=512,
                    sources=["upload"],
                    placeholder="🫧 点击此处或下面的图标上传图像 🫧",
                    )

            prompt = gr.Textbox(label="修改指令", placeholder="😜 在此处输入你对参考图像的修改预期 😜", value="", lines=1)
            run_button = gr.Button("💫 图像编辑")

            vlm_model_dropdown = gr.Dropdown(label="VLM 模型", choices=VLM_MODEL_NAMES, value=DEFAULT_VLM_MODEL_NAME, interactive=True)
            with gr.Group():
                with gr.Row():
                    # GPT4o_KEY = gr.Textbox(label="GPT4o API Key", placeholder="Please input your GPT4o API Key when use GPT4o VLM (highly recommended).", value="", lines=1)
                    GPT4o_KEY = gr.Textbox(label="密钥输入", value="sk-d145b963a92649a88843caeb741e8bbc", lines=1)
                    GPT4o_KEY_submit = gr.Button("🙈 验证")

            aspect_ratio = gr.Dropdown(label="Output aspect ratio", choices=ASPECT_RATIO_LABELS, value=DEFAULT_ASPECT_RATIO)
            resize_default = gr.Checkbox(label="Short edge resize to 640px", value=True)

            with gr.Row():
                mask_button = gr.Button("💎 掩膜生成")
                random_mask_button = gr.Button("Square/Circle Mask ")

            with gr.Row():
                generate_target_prompt_button = gr.Button("Generate Target Prompt")

            target_prompt = gr.Text(
                        label="Input Target Prompt",
                        max_lines=5,
                        placeholder="VLM-generated target prompt, you can first generate if and then modify it (optional)",
                        value='',
                        lines=2
                    )

            with gr.Accordion("Advanced Options", open=False, elem_id="accordion1"):
                base_model_dropdown = gr.Dropdown(label="Base model", choices=BASE_MODELS, value=DEFAULT_BASE_MODEL, interactive=True)
                negative_prompt = gr.Text(
                        label="Negative Prompt",
                        max_lines=5,
                        placeholder="Please input your negative prompt",
                        value='ugly, low quality', lines=1
                    )

                control_strength = gr.Slider(
                    label="Control Strength: ", show_label=True, minimum=0, maximum=1.1, value=1, step=0.01
                    )
                with gr.Group():
                    seed = gr.Slider(
                        label="Seed: ", minimum=0, maximum=2147483647, step=1, value=648464818
                    )
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=False)

                blending = gr.Checkbox(label="Blending mode", value=True)

                num_samples = gr.Slider(
                    label="Num samples", minimum=0, maximum=4, step=1, value=4
                )

                with gr.Group():
                    with gr.Row():
                        guidance_scale = gr.Slider(
                            label="Guidance scale",
                            minimum=1,
                            maximum=12,
                            step=0.1,
                            value=7.5,
                        )
                        num_inference_steps = gr.Slider(
                            label="Number of inference steps",
                            minimum=1,
                            maximum=50,
                            step=1,
                            value=50,
                        )

        with gr.Group(visible=True):
            # Description generated by BLIP
            blip_output = gr.Textbox(label="原图描述", placeholder="💬 BLIP生成的图像基础描述 💬", interactive=True, lines=3)
            # DeepSeek API verification
            with gr.Row():
                deepseek_key = gr.Textbox(label="密钥输入", value="sk-d145b963a92649a88843caeb741e8bbc", lines=1)
                verify_deepseek = gr.Button("🙈 验证")
            # Integrated (enhanced) description area
            with gr.Row():
                enhanced_output = gr.Textbox(label="描述整合", placeholder="💭 DeepSeek生成的增强描述 💭", interactive=True, lines=3)
                enhance_button = gr.Button("✨ 整合")
            # Decomposed description area
            with gr.Row():
                decomposed_output = gr.Textbox(label="描述分解", placeholder="🔍 DeepSeek生成的分解描述 🔍", interactive=True, lines=3)
                decompose_button = gr.Button("🔧 分解")
        with gr.Row():
            with gr.Tab(elem_classes="feedback", label="Masked Image"):
                masked_gallery = gr.Gallery(label='Masked Image', show_label=True, elem_id="gallery", preview=True, height=360)
            with gr.Tab(elem_classes="feedback", label="Mask"):
                mask_gallery = gr.Gallery(label='Mask', show_label=True, elem_id="gallery", preview=True, height=360)

            invert_mask_button = gr.Button("Invert Mask")
            dilation_size = gr.Slider(
                label="Dilation size: ", minimum=0, maximum=50, step=1, value=20
            )
            with gr.Row():
                dilation_mask_button = gr.Button("Dilation Generated Mask")
                erosion_mask_button = gr.Button("Erosion Generated Mask")

            moving_pixels = gr.Slider(
                label="Moving pixels:", show_label=True, minimum=0, maximum=50, value=4, step=1
            )
            with gr.Row():
                move_left_button = gr.Button("Move Left")
                move_right_button = gr.Button("Move Right")
            with gr.Row():
                move_up_button = gr.Button("Move Up")
                move_down_button = gr.Button("Move Down")

            with gr.Tab(elem_classes="feedback", label="Output"):
                result_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", preview=True, height=400)

            # target_prompt_output = gr.Text(label="Output Target Prompt", value="", lines=1, interactive=False)

            reset_button = gr.Button("Reset")

            init_type = gr.Textbox(label="Init Name", value="", visible=False)
            example_type = gr.Textbox(label="Example Name", value="", visible=False)

    with gr.Row():
        example = gr.Examples(
            label="Quick Example",
            examples=EXAMPLES,
            inputs=[input_image, prompt, seed, init_type, example_type, blending, resize_default, vlm_model_dropdown],
            examples_per_page=10,
            cache_examples=False,
        )

    with gr.Accordion(label="🎬 Feature Details:", open=True, elem_id="accordion"):
        with gr.Row(equal_height=True):
            gr.Markdown(tips)

    with gr.Row():
        gr.Markdown(citation)

    ## gr.examples can not be used to update the gr.Gallery, so we need to use the following two functions to update the gr.Gallery.
    ## And we need to solve the conflict between the upload and change example functions.
    input_image.upload(
        init_img,
        [input_image, init_type, prompt, aspect_ratio, example_change_times],
        [input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, init_type, aspect_ratio, resize_default, invert_mask_state, example_change_times]
    )
    example_type.change(fn=update_example, inputs=[example_type, prompt, example_change_times], outputs=[input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, target_prompt, invert_mask_state, example_change_times])

    ## vlm and base model dropdown
    vlm_model_dropdown.change(fn=update_vlm_model, inputs=[vlm_model_dropdown], outputs=[status])
    base_model_dropdown.change(fn=update_base_model, inputs=[base_model_dropdown], outputs=[status])

    GPT4o_KEY_submit.click(fn=submit_GPT4o_KEY, inputs=[GPT4o_KEY], outputs=[GPT4o_KEY, vlm_model_dropdown])
    invert_mask_button.click(fn=invert_mask, inputs=[input_image, original_image, original_mask], outputs=[masked_gallery, mask_gallery, original_mask, invert_mask_state])

    ips=[input_image,
         original_image,
         original_mask,
         prompt,
         negative_prompt,
         control_strength,
         seed,
         randomize_seed,
         guidance_scale,
         num_inference_steps,
         num_samples,
         blending,
         category,
         target_prompt,
         resize_default,
         aspect_ratio,
         invert_mask_state]

    ## run brushedit
    run_button.click(fn=process, inputs=ips, outputs=[result_gallery, mask_gallery, masked_gallery, prompt, target_prompt, invert_mask_state])

    ## mask func
    mask_button.click(fn=process_mask, inputs=[input_image, original_image, prompt, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask, category])
    random_mask_button.click(fn=process_random_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
    dilation_mask_button.click(fn=process_dilation_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[masked_gallery, mask_gallery, original_mask])
    erosion_mask_button.click(fn=process_erosion_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[masked_gallery, mask_gallery, original_mask])

    ## move mask func
    move_left_button.click(fn=move_mask_left, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
    move_right_button.click(fn=move_mask_right, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
    move_up_button.click(fn=move_mask_up, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
    move_down_button.click(fn=move_mask_down, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])

    ## prompt func
    generate_target_prompt_button.click(fn=generate_target_prompt, inputs=[input_image, original_image, prompt], outputs=[target_prompt])

    ## reset func
    reset_button.click(fn=reset_func, inputs=[input_image, original_image, original_mask, prompt, target_prompt], outputs=[input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, resize_default, invert_mask_state])

    # Bind the newly added event handlers
    input_image.upload(fn=generate_blip_description, inputs=[input_image], outputs=[blip_description, blip_output])
    verify_deepseek.click(fn=verify_deepseek_api, outputs=[deepseek_verified, deepseek_key])
    enhance_button.click(fn=enhance_description, inputs=[blip_description, prompt], outputs=[enhanced_description, enhanced_output])
    decompose_button.click(fn=decompose_description, inputs=[enhanced_description], outputs=[decomposed_description, decomposed_output])

demo.launch(server_name="0.0.0.0", server_port=12345, share=True)

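The block above wires a large number of callbacks onto a single gr.Blocks app. As a compact illustration of the pattern used throughout (a component's .click/.upload event with explicit inputs/outputs lists), here is a self-contained toy sketch that is not part of the Space itself:

import gradio as gr

def shout(text):
    return text.upper(), f"{len(text)} characters"

with gr.Blocks() as toy_demo:
    inp = gr.Textbox(label="Input")
    out = gr.Textbox(label="Output")
    info = gr.Textbox(label="Info")
    btn = gr.Button("Run")
    # Same pattern as run_button.click(fn=process, inputs=ips, outputs=[...]) above:
    # the callback's return values are written to the listed output components in order.
    btn.click(fn=shout, inputs=[inp], outputs=[out, info])

toy_demo.launch()
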
llm_pipeline.py
ADDED
@@ -0,0 +1,35 @@
import gradio as gr
from openai import OpenAI
from app.deepseek.instructions import create_apply_editing_messages_deepseek


def run_deepseek_llm_inference(llm_model, messages):
    response = llm_model.chat.completions.create(
        model="deepseek-chat",
        messages=messages
    )
    response_str = response.choices[0].message.content
    return response_str


from openai import AuthenticationError, APIConnectionError, RateLimitError, BadRequestError, APIError

def llm_response_prompt_after_apply_instruction(image_caption, editing_prompt):
    try:
        messages = create_apply_editing_messages_deepseek(image_caption, editing_prompt)
        response_str = run_deepseek_llm_inference(llm_model, messages)
        return response_str
    except AuthenticationError as e:
        raise gr.Error(f"认证失败: 请检查API密钥是否正确 (错误详情: {e.message})")
    except APIConnectionError as e:
        raise gr.Error(f"连接异常: 请检查网络连接后重试 (错误详情: {e.message})")
    except RateLimitError as e:
        raise gr.Error(f"请求超限: 请稍后重试 (错误详情: {e.message})")
    except BadRequestError as e:
        if "model" in e.message.lower():
            raise gr.Error(f"模型错误: 请检查模型名称是否正确 (错误详情: {e.message})")
        raise gr.Error(f"无效请求: 请检查输入参数 (错误详情: {e.message})")
    except APIError as e:
        raise gr.Error(f"API异常: 服务端返回错误 (错误详情: {e.message})")
    except Exception as e:
        raise gr.Error(f"未预期错误: {str(e)},请检查控制台日志获取详细信息")
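llm_response_prompt_after_apply_instruction reads a module-level llm_model client that is not created in this file. A hedged usage sketch, assuming the caller injects a DeepSeek-backed OpenAI client before use (the key value is a placeholder):

from openai import OpenAI
import llm_pipeline

# Assumption: the surrounding app is expected to provide this client; it is not defined in llm_pipeline.py.
llm_pipeline.llm_model = OpenAI(api_key="YOUR_DEEPSEEK_KEY", base_url="https://api.deepseek.com")

target_caption = llm_pipeline.llm_response_prompt_after_apply_instruction(
    image_caption="a red bicycle leaning against a wall",
    editing_prompt="make the bicycle blue",
)
print(target_caption)
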
llm_template.py
ADDED
@@ -0,0 +1,21 @@
import os
import sys
import torch
from openai import OpenAI

## init device
device = "cpu"
torch_dtype = torch.float16


llms_list = [
    {
        "type": "deepseek",
        "name": "deepseek",
        "local_path": "",
        "processor": "",
        "model": ""
    },
]

llms_template = {k["name"]: (k["type"], k["local_path"], k["processor"], k["model"]) for k in llms_list}
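llms_template maps each entry's name to a (type, local_path, processor, model) tuple, so consuming it is a plain dictionary lookup:

from llm_template import llms_template

llm_type, local_path, processor, model = llms_template["deepseek"]
print(llm_type)   # "deepseek"
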
vlm_pipeline.py
ADDED
@@ -0,0 +1,266 @@
import base64
import re
import torch

from PIL import Image
from io import BytesIO
import numpy as np
import gradio as gr

from openai import OpenAI
from transformers import (LlavaNextForConditionalGeneration, Qwen2VLForConditionalGeneration)


from qwen_vl_utils import process_vision_info

from app.gpt4_o.instructions import (
    create_editing_category_messages_gpt4o,
    create_ori_object_messages_gpt4o,
    create_add_object_messages_gpt4o,
    create_apply_editing_messages_gpt4o)

from app.llava.instructions import (
    create_editing_category_messages_llava,
    create_ori_object_messages_llava,
    create_add_object_messages_llava,
    create_apply_editing_messages_llava)

from app.qwen2.instructions import (
    create_editing_category_messages_qwen2,
    create_ori_object_messages_qwen2,
    create_add_object_messages_qwen2,
    create_apply_editing_messages_qwen2)

from app.deepseek.instructions import (
    create_editing_category_messages_deepseek,
    create_ori_object_messages_deepseek,
    create_apply_editing_messages_deepseek
)

from app.utils.utils import run_grounded_sam


def encode_image(img):
    img = Image.fromarray(img.astype('uint8'))
    buffered = BytesIO()
    img.save(buffered, format="PNG")
    img_bytes = buffered.getvalue()
    return base64.b64encode(img_bytes).decode('utf-8')


def run_gpt4o_vl_inference(vlm_model,
                           messages):
    response = vlm_model.chat.completions.create(
        model="gpt-4o-2024-08-06",
        messages=messages
    )
    response_str = response.choices[0].message.content
    return response_str


def run_deepseek_inference(llm_model,
                           messages):
    try:
        response = llm_model.chat.completions.create(
            model="deepseek-chat",
            messages=messages
        )
        response_str = response.choices[0].message.content
        return response_str
    except Exception as e:
        return "Invalid DeepSeek API Key"


def run_llava_next_inference(vlm_processor, vlm_model, messages, image, device="cuda"):
    prompt = vlm_processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = vlm_processor(images=image, text=prompt, return_tensors="pt").to(device)
    output = vlm_model.generate(**inputs, max_new_tokens=200)
    generated_ids_trimmed = [
        out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, output)
    ]
    response_str = vlm_processor.decode(generated_ids_trimmed[0], skip_special_tokens=True)

    return response_str


def run_qwen2_vl_inference(vlm_processor, vlm_model, messages, image, device="cuda"):
    text = vlm_processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = vlm_processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(device)
    generated_ids = vlm_model.generate(**inputs, max_new_tokens=128)
    generated_ids_trimmed = [
        out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    response_str = vlm_processor.decode(generated_ids_trimmed[0], skip_special_tokens=True)
    return response_str


### response editing type
def vlm_response_editing_type(vlm_processor,
                              vlm_model,
                              llm_model,
                              image,
                              image_caption,
                              editing_prompt,
                              device):

    if isinstance(vlm_model, OpenAI):
        messages = create_editing_category_messages_gpt4o(editing_prompt)
        response_str = run_gpt4o_vl_inference(vlm_model, messages)
    elif isinstance(vlm_model, LlavaNextForConditionalGeneration):
        messages = create_editing_category_messages_llava(editing_prompt)
        response_str = run_llava_next_inference(vlm_processor, vlm_model, messages, image, device=device)
    elif isinstance(vlm_model, Qwen2VLForConditionalGeneration):
        # messages = create_editing_category_messages_qwen2(editing_prompt)
        messages = create_editing_category_messages_qwen2(image_caption, editing_prompt)
        response_str = run_qwen2_vl_inference(vlm_processor, vlm_model, messages, image, device=device)
        # messages = create_editing_category_messages_deepseek(image_caption, editing_prompt)
        # response_str = run_deepseek_inference(llm_model, messages)

    try:
        for category_name in ["Addition","Remove","Local","Global","Background"]:
            if category_name.lower() in response_str.lower():
                return category_name
    except Exception as e:
        raise gr.Error("Please input OpenAI API Key. Or please input correct commands, including add, delete, and modify commands. If it still does not work, please switch to a more powerful VLM.")


### response object to be edited
def vlm_response_object_wait_for_edit(vlm_processor,
                                      vlm_model,
                                      llm_model,
                                      image,
                                      image_caption,
                                      category,
                                      editing_prompt,
                                      device):
    if category in ["Background", "Global", "Addition"]:
        edit_object = "nan"
        return edit_object

    if isinstance(vlm_model, OpenAI):
        messages = create_ori_object_messages_gpt4o(editing_prompt)
        response_str = run_gpt4o_vl_inference(vlm_model, messages)
    elif isinstance(vlm_model, LlavaNextForConditionalGeneration):
        messages = create_ori_object_messages_llava(editing_prompt)
        response_str = run_llava_next_inference(vlm_processor, vlm_model, messages, image, device)
    elif isinstance(vlm_model, Qwen2VLForConditionalGeneration):
        # messages = create_ori_object_messages_qwen2(editing_prompt)
        messages = create_ori_object_messages_qwen2(image_caption, editing_prompt)
        response_str = run_qwen2_vl_inference(vlm_processor, vlm_model, messages, image, device)
        # messages = create_ori_object_messages_deepseek(image_caption, editing_prompt)
        # response_str = run_deepseek_inference(llm_model, messages)
    return response_str


### response mask
def vlm_response_mask(vlm_processor,
                      vlm_model,
                      category,
                      image,
                      editing_prompt,
                      object_wait_for_edit,
                      sam=None,
                      sam_predictor=None,
                      sam_automask_generator=None,
                      groundingdino_model=None,
                      device=None,
                      ):
    mask = None
    if editing_prompt is None or len(editing_prompt)==0:
        raise gr.Error("Please input the editing instruction!")
    height, width = image.shape[:2]
    if category=="Addition":
        try:
            if isinstance(vlm_model, OpenAI):
                base64_image = encode_image(image)
                messages = create_add_object_messages_gpt4o(editing_prompt, base64_image, height=height, width=width)
                response_str = run_gpt4o_vl_inference(vlm_model, messages)
            elif isinstance(vlm_model, LlavaNextForConditionalGeneration):
                messages = create_add_object_messages_llava(editing_prompt, height=height, width=width)
                response_str = run_llava_next_inference(vlm_processor, vlm_model, messages, image, device)
            elif isinstance(vlm_model, Qwen2VLForConditionalGeneration):
                base64_image = encode_image(image)
                messages = create_add_object_messages_qwen2(editing_prompt, base64_image, height=height, width=width)
                response_str = run_qwen2_vl_inference(vlm_processor, vlm_model, messages, image, device)
            pattern = r'\[\d{1,3}(?:,\s*\d{1,3}){3}\]'
            box = re.findall(pattern, response_str)
            box = box[0][1:-1].split(",")
            for i in range(len(box)):
                box[i] = int(box[i])
            cus_mask = np.zeros((height, width))
            cus_mask[box[1]: box[1]+box[3], box[0]: box[0]+box[2]]=255
            mask = cus_mask
        except:
            raise gr.Error("Please set the mask manually, currently the VLM cannot output the mask!")

    elif category=="Background":
        labels = "background"
    elif category=="Global":
        mask = 255 * np.zeros((height, width))
    else:
        labels = object_wait_for_edit

    if mask is None:
        for thresh in [0.3,0.25,0.2,0.15,0.1,0.05,0]:
            try:
                detections = run_grounded_sam(
                    input_image={"image":Image.fromarray(image.astype('uint8')),
                                 "mask":None},
                    text_prompt=labels,
                    task_type="seg",
                    box_threshold=thresh,
                    text_threshold=0.25,
                    # iou_threshold=0.5,
                    # scribble_mode="split",
                    sam=sam,
                    sam_predictor=sam_predictor,
                    # sam_automask_generator=sam_automask_generator,
                    groundingdino_model=groundingdino_model,
                    device=device,
                )
                mask = np.array(detections[0,0,...].cpu()) * 255
                break
            except:
                print(f"wrong in threshhold: {thresh}, continue")
                continue
    return mask


def vlm_response_prompt_after_apply_instruction(vlm_processor,
                                                vlm_model,
                                                llm_model,
                                                image,
                                                image_caption,
                                                editing_prompt,
                                                device):

    try:
        if isinstance(vlm_model, OpenAI):
            base64_image = encode_image(image)
            messages = create_apply_editing_messages_gpt4o(editing_prompt, base64_image)
            response_str = run_gpt4o_vl_inference(vlm_model, messages)
        elif isinstance(vlm_model, LlavaNextForConditionalGeneration):
            messages = create_apply_editing_messages_llava(editing_prompt)
            response_str = run_llava_next_inference(vlm_processor, vlm_model, messages, image, device)
        elif isinstance(vlm_model, Qwen2VLForConditionalGeneration):
            # base64_image = encode_image(image)
            # messages = create_apply_editing_messages_qwen2(editing_prompt, base64_image)
            messages = create_apply_editing_messages_qwen2(image_caption, editing_prompt)
            response_str = run_qwen2_vl_inference(vlm_processor, vlm_model, messages, image, device)
            # messages = create_apply_editing_messages_deepseek(image_caption, editing_prompt)
            # response_str = run_deepseek_inference(llm_model, messages)
        else:
            raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
    except Exception as e:
        raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
    return response_str
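run_qwen2_vl_inference relies on qwen_vl_utils.process_vision_info, which expects the standard Qwen2-VL chat message format. The actual prompts are built in app/qwen2/instructions.py (not shown here); a typical message list for this pipeline would look roughly like the sketch below, with the image path and text being illustrative placeholders:

# Illustrative message structure only; the real content comes from app/qwen2/instructions.py.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "file:///path/to/input.png"},
            {"type": "text", "text": "Describe how the image should look after the edit: make the sky pink."},
        ],
    }
]
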
vlm_pipeline_noqwen.py
ADDED
@@ -0,0 +1,263 @@
import base64
import re
import torch

from PIL import Image
from io import BytesIO
import numpy as np
import gradio as gr

from openai import OpenAI
from transformers import (LlavaNextForConditionalGeneration, Qwen2VLForConditionalGeneration)


from qwen_vl_utils import process_vision_info

from app.gpt4_o.instructions import (
    create_editing_category_messages_gpt4o,
    create_ori_object_messages_gpt4o,
    create_add_object_messages_gpt4o,
    create_apply_editing_messages_gpt4o)

from app.llava.instructions import (
    create_editing_category_messages_llava,
    create_ori_object_messages_llava,
    create_add_object_messages_llava,
    create_apply_editing_messages_llava)

from app.qwen2.instructions import (
    create_editing_category_messages_qwen2,
    create_ori_object_messages_qwen2,
    create_add_object_messages_qwen2,
    create_apply_editing_messages_qwen2)

from app.deepseek.instructions import (
    create_editing_category_messages_deepseek,
    create_ori_object_messages_deepseek,
    create_apply_editing_messages_deepseek
)

from app.utils.utils import run_grounded_sam


def encode_image(img):
    img = Image.fromarray(img.astype('uint8'))
    buffered = BytesIO()
    img.save(buffered, format="PNG")
    img_bytes = buffered.getvalue()
    return base64.b64encode(img_bytes).decode('utf-8')


def run_gpt4o_vl_inference(vlm_model,
                           messages):
    response = vlm_model.chat.completions.create(
        model="gpt-4o-2024-08-06",
        messages=messages
    )
    response_str = response.choices[0].message.content
    return response_str


def run_deepseek_inference(llm_model,
                           messages):
    try:
        response = llm_model.chat.completions.create(
            model="deepseek-chat",
            messages=messages
        )
        response_str = response.choices[0].message.content
        return response_str
    except Exception as e:
        return "Invalid DeepSeek API Key"


def run_llava_next_inference(vlm_processor, vlm_model, messages, image, device="cuda"):
    prompt = vlm_processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = vlm_processor(images=image, text=prompt, return_tensors="pt").to(device)
    output = vlm_model.generate(**inputs, max_new_tokens=200)
    generated_ids_trimmed = [
        out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, output)
    ]
    response_str = vlm_processor.decode(generated_ids_trimmed[0], skip_special_tokens=True)

    return response_str


def run_qwen2_vl_inference(vlm_processor, vlm_model, messages, image, device="cuda"):
    text = vlm_processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = vlm_processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(device)
    generated_ids = vlm_model.generate(**inputs, max_new_tokens=128)
    generated_ids_trimmed = [
        out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    response_str = vlm_processor.decode(generated_ids_trimmed[0], skip_special_tokens=True)
    return response_str


### response editing type
def vlm_response_editing_type(vlm_processor,
                              vlm_model,
                              llm_model,
                              image,
                              image_caption,
                              editing_prompt,
                              device):

    if isinstance(vlm_model, OpenAI):
        messages = create_editing_category_messages_gpt4o(editing_prompt)
        response_str = run_gpt4o_vl_inference(vlm_model, messages)
    elif isinstance(vlm_model, LlavaNextForConditionalGeneration):
        messages = create_editing_category_messages_llava(editing_prompt)
        response_str = run_llava_next_inference(vlm_processor, vlm_model, messages, image, device=device)
    elif isinstance(vlm_model, Qwen2VLForConditionalGeneration):
        # messages = create_editing_category_messages_qwen2(editing_prompt)
        # response_str = run_qwen2_vl_inference(vlm_processor, vlm_model, messages, image, device=device)
        messages = create_editing_category_messages_deepseek(image_caption, editing_prompt)
        response_str = run_deepseek_inference(llm_model, messages)

    try:
        for category_name in ["Addition","Remove","Local","Global","Background"]:
            if category_name.lower() in response_str.lower():
                return category_name
    except Exception as e:
        raise gr.Error("Please input OpenAI API Key. Or please input correct commands, including add, delete, and modify commands. If it still does not work, please switch to a more powerful VLM.")


### response object to be edited
def vlm_response_object_wait_for_edit(vlm_processor,
                                      vlm_model,
                                      llm_model,
                                      image,
                                      image_caption,
                                      category,
                                      editing_prompt,
                                      device):
    if category in ["Background", "Global", "Addition"]:
        edit_object = "nan"
        return edit_object

    if isinstance(vlm_model, OpenAI):
        messages = create_ori_object_messages_gpt4o(editing_prompt)
        response_str = run_gpt4o_vl_inference(vlm_model, messages)
    elif isinstance(vlm_model, LlavaNextForConditionalGeneration):
        messages = create_ori_object_messages_llava(editing_prompt)
        response_str = run_llava_next_inference(vlm_processor, vlm_model, messages, image, device)
    elif isinstance(vlm_model, Qwen2VLForConditionalGeneration):
        # messages = create_ori_object_messages_qwen2(editing_prompt)
        # response_str = run_qwen2_vl_inference(vlm_processor, vlm_model, messages, image, device)
        messages = create_ori_object_messages_deepseek(image_caption, editing_prompt)
        response_str = run_deepseek_inference(llm_model, messages)
    return response_str


### response mask
def vlm_response_mask(vlm_processor,
                      vlm_model,
                      category,
                      image,
                      editing_prompt,
                      object_wait_for_edit,
                      sam=None,
                      sam_predictor=None,
                      sam_automask_generator=None,
                      groundingdino_model=None,
                      device=None,
                      ):
    mask = None
    if editing_prompt is None or len(editing_prompt)==0:
        raise gr.Error("Please input the editing instruction!")
    height, width = image.shape[:2]
    if category=="Addition":
        try:
            if isinstance(vlm_model, OpenAI):
                base64_image = encode_image(image)
                messages = create_add_object_messages_gpt4o(editing_prompt, base64_image, height=height, width=width)
                response_str = run_gpt4o_vl_inference(vlm_model, messages)
            elif isinstance(vlm_model, LlavaNextForConditionalGeneration):
                messages = create_add_object_messages_llava(editing_prompt, height=height, width=width)
                response_str = run_llava_next_inference(vlm_processor, vlm_model, messages, image, device)
            elif isinstance(vlm_model, Qwen2VLForConditionalGeneration):
                base64_image = encode_image(image)
                messages = create_add_object_messages_qwen2(editing_prompt, base64_image, height=height, width=width)
                response_str = run_qwen2_vl_inference(vlm_processor, vlm_model, messages, image, device)
            pattern = r'\[\d{1,3}(?:,\s*\d{1,3}){3}\]'
            box = re.findall(pattern, response_str)
            box = box[0][1:-1].split(",")
            for i in range(len(box)):
                box[i] = int(box[i])
|
198 |
+
cus_mask = np.zeros((height, width))
|
199 |
+
cus_mask[box[1]: box[1]+box[3], box[0]: box[0]+box[2]]=255
|
200 |
+
mask = cus_mask
|
201 |
+
except:
|
202 |
+
raise gr.Error("Please set the mask manually, currently the VLM cannot output the mask!")
|
203 |
+
|
204 |
+
elif category=="Background":
|
205 |
+
labels = "background"
|
206 |
+
elif category=="Global":
|
207 |
+
mask = 255 * np.zeros((height, width))
|
208 |
+
else:
|
209 |
+
labels = object_wait_for_edit
|
210 |
+
|
211 |
+
if mask is None:
|
212 |
+
for thresh in [0.3,0.25,0.2,0.15,0.1,0.05,0]:
|
213 |
+
try:
|
214 |
+
detections = run_grounded_sam(
|
215 |
+
input_image={"image":Image.fromarray(image.astype('uint8')),
|
216 |
+
"mask":None},
|
217 |
+
text_prompt=labels,
|
218 |
+
task_type="seg",
|
219 |
+
box_threshold=thresh,
|
220 |
+
text_threshold=0.25,
|
221 |
+
# iou_threshold=0.5,
|
222 |
+
# scribble_mode="split",
|
223 |
+
sam=sam,
|
224 |
+
sam_predictor=sam_predictor,
|
225 |
+
# sam_automask_generator=sam_automask_generator,
|
226 |
+
groundingdino_model=groundingdino_model,
|
227 |
+
device=device,
|
228 |
+
)
|
229 |
+
mask = np.array(detections[0,0,...].cpu()) * 255
|
230 |
+
break
|
231 |
+
except:
|
232 |
+
print(f"wrong in threshhold: {thresh}, continue")
|
233 |
+
continue
|
234 |
+
return mask
|
235 |
+
|
236 |
+
|
237 |
+
def vlm_response_prompt_after_apply_instruction(vlm_processor,
|
238 |
+
vlm_model,
|
239 |
+
llm_model,
|
240 |
+
image,
|
241 |
+
image_caption,
|
242 |
+
editing_prompt,
|
243 |
+
device):
|
244 |
+
|
245 |
+
try:
|
246 |
+
if isinstance(vlm_model, OpenAI):
|
247 |
+
base64_image = encode_image(image)
|
248 |
+
messages = create_apply_editing_messages_gpt4o(editing_prompt, base64_image)
|
249 |
+
response_str = run_gpt4o_vl_inference(vlm_model, messages)
|
250 |
+
elif isinstance(vlm_model, LlavaNextForConditionalGeneration):
|
251 |
+
messages = create_apply_editing_messages_llava(editing_prompt)
|
252 |
+
response_str = run_llava_next_inference(vlm_processor, vlm_model, messages, image, device)
|
253 |
+
elif isinstance(vlm_model, Qwen2VLForConditionalGeneration):
|
254 |
+
# base64_image = encode_image(image)
|
255 |
+
# messages = create_apply_editing_messages_qwen2(editing_prompt, base64_image)
|
256 |
+
# response_str = run_qwen2_vl_inference(vlm_processor, vlm_model, messages, image, device)
|
257 |
+
messages = create_apply_editing_messages_deepseek(image_caption, editing_prompt)
|
258 |
+
response_str = run_deepseek_inference(llm_model, messages)
|
259 |
+
else:
|
260 |
+
raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
|
261 |
+
except Exception as e:
|
262 |
+
raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
|
263 |
+
return response_str
|
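Taken together, the functions above form a four-step pipeline: classify the editing instruction, name the object to be edited, ground it into a mask (GroundingDINO + SAM), and rewrite the instruction into a target prompt. The snippet below is a minimal usage sketch and is not part of the uploaded files; it assumes the model handles (vlm_processor, vlm_model, llm_model, sam, sam_predictor, groundingdino_model), the image array, and its caption are created elsewhere in the host Gradio app, and the helper name run_edit_pipeline is hypothetical.

# Hypothetical driver, for illustration only; all handles are assumed to be
# loaded by the host app before this is called.
def run_edit_pipeline(vlm_processor, vlm_model, llm_model,
                      sam, sam_predictor, groundingdino_model,
                      image, image_caption, editing_prompt, device="cuda"):
    # 1. Decide which kind of edit the instruction asks for.
    category = vlm_response_editing_type(
        vlm_processor, vlm_model, llm_model, image, image_caption, editing_prompt, device)
    # 2. Name the object to be edited ("nan" for Background/Global/Addition).
    target_object = vlm_response_object_wait_for_edit(
        vlm_processor, vlm_model, llm_model, image, image_caption, category, editing_prompt, device)
    # 3. Ground the object into a binary mask with GroundingDINO + SAM.
    mask = vlm_response_mask(
        vlm_processor, vlm_model, category, image, editing_prompt, target_object,
        sam=sam, sam_predictor=sam_predictor,
        groundingdino_model=groundingdino_model, device=device)
    # 4. Rewrite the instruction into a target caption for the inpainting model.
    target_prompt = vlm_response_prompt_after_apply_instruction(
        vlm_processor, vlm_model, llm_model, image, image_caption, editing_prompt, device)
    return category, target_object, mask, target_prompt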
vlm_pipeline_old.py
ADDED
@@ -0,0 +1,228 @@
import base64
import re
import torch

from PIL import Image
from io import BytesIO
import numpy as np
import gradio as gr

from openai import OpenAI
from transformers import (LlavaNextForConditionalGeneration, Qwen2VLForConditionalGeneration)
from qwen_vl_utils import process_vision_info

from app.gpt4_o.instructions import (
    create_editing_category_messages_gpt4o,
    create_ori_object_messages_gpt4o,
    create_add_object_messages_gpt4o,
    create_apply_editing_messages_gpt4o)

from app.llava.instructions import (
    create_editing_category_messages_llava,
    create_ori_object_messages_llava,
    create_add_object_messages_llava,
    create_apply_editing_messages_llava)

from app.qwen2.instructions import (
    create_editing_category_messages_qwen2,
    create_ori_object_messages_qwen2,
    create_add_object_messages_qwen2,
    create_apply_editing_messages_qwen2)

from app.utils.utils import run_grounded_sam


def encode_image(img):
    img = Image.fromarray(img.astype('uint8'))
    buffered = BytesIO()
    img.save(buffered, format="PNG")
    img_bytes = buffered.getvalue()
    return base64.b64encode(img_bytes).decode('utf-8')


def run_gpt4o_vl_inference(vlm_model,
                           messages):
    response = vlm_model.chat.completions.create(
        model="gpt-4o-2024-08-06",
        messages=messages
    )
    response_str = response.choices[0].message.content
    return response_str

def run_llava_next_inference(vlm_processor, vlm_model, messages, image, device="cuda"):
    prompt = vlm_processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = vlm_processor(images=image, text=prompt, return_tensors="pt").to(device)
    output = vlm_model.generate(**inputs, max_new_tokens=200)
    generated_ids_trimmed = [
        out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, output)
    ]
    response_str = vlm_processor.decode(generated_ids_trimmed[0], skip_special_tokens=True)

    return response_str

def run_qwen2_vl_inference(vlm_processor, vlm_model, messages, image, device="cuda"):
    text = vlm_processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = vlm_processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(device)
    generated_ids = vlm_model.generate(**inputs, max_new_tokens=128)
    generated_ids_trimmed = [
        out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    response_str = vlm_processor.decode(generated_ids_trimmed[0], skip_special_tokens=True)
    return response_str


### response editing type
def vlm_response_editing_type(vlm_processor,
                              vlm_model,
                              image,
                              editing_prompt,
                              device):

    if isinstance(vlm_model, OpenAI):
        messages = create_editing_category_messages_gpt4o(editing_prompt)
        response_str = run_gpt4o_vl_inference(vlm_model, messages)
    elif isinstance(vlm_model, LlavaNextForConditionalGeneration):
        messages = create_editing_category_messages_llava(editing_prompt)
        response_str = run_llava_next_inference(vlm_processor, vlm_model, messages, image, device=device)
    elif isinstance(vlm_model, Qwen2VLForConditionalGeneration):
        messages = create_editing_category_messages_qwen2(editing_prompt)
        response_str = run_qwen2_vl_inference(vlm_processor, vlm_model, messages, image, device=device)

    try:
        for category_name in ["Addition","Remove","Local","Global","Background"]:
            if category_name.lower() in response_str.lower():
                return category_name
    except Exception as e:
        raise gr.Error("Please input OpenAI API Key. Or please input correct commands, including add, delete, and modify commands. If it still does not work, please switch to a more powerful VLM.")


### response object to be edited
def vlm_response_object_wait_for_edit(vlm_processor,
                                      vlm_model,
                                      image,
                                      category,
                                      editing_prompt,
                                      device):
    if category in ["Background", "Global", "Addition"]:
        edit_object = "nan"
        return edit_object

    if isinstance(vlm_model, OpenAI):
        messages = create_ori_object_messages_gpt4o(editing_prompt)
        response_str = run_gpt4o_vl_inference(vlm_model, messages)
    elif isinstance(vlm_model, LlavaNextForConditionalGeneration):
        messages = create_ori_object_messages_llava(editing_prompt)
        response_str = run_llava_next_inference(vlm_processor, vlm_model, messages, image, device)
    elif isinstance(vlm_model, Qwen2VLForConditionalGeneration):
        messages = create_ori_object_messages_qwen2(editing_prompt)
        response_str = run_qwen2_vl_inference(vlm_processor, vlm_model, messages, image, device)
    return response_str


### response mask
def vlm_response_mask(vlm_processor,
                      vlm_model,
                      category,
                      image,
                      editing_prompt,
                      object_wait_for_edit,
                      sam=None,
                      sam_predictor=None,
                      sam_automask_generator=None,
                      groundingdino_model=None,
                      device=None,
                      ):
    mask = None
    if editing_prompt is None or len(editing_prompt)==0:
        raise gr.Error("Please input the editing instruction!")
    height, width = image.shape[:2]
    if category=="Addition":
        try:
            if isinstance(vlm_model, OpenAI):
                base64_image = encode_image(image)
                messages = create_add_object_messages_gpt4o(editing_prompt, base64_image, height=height, width=width)
                response_str = run_gpt4o_vl_inference(vlm_model, messages)
            elif isinstance(vlm_model, LlavaNextForConditionalGeneration):
                messages = create_add_object_messages_llava(editing_prompt, height=height, width=width)
                response_str = run_llava_next_inference(vlm_processor, vlm_model, messages, image, device)
            elif isinstance(vlm_model, Qwen2VLForConditionalGeneration):
                base64_image = encode_image(image)
                messages = create_add_object_messages_qwen2(editing_prompt, base64_image, height=height, width=width)
                response_str = run_qwen2_vl_inference(vlm_processor, vlm_model, messages, image, device)
            pattern = r'\[\d{1,3}(?:,\s*\d{1,3}){3}\]'
            box = re.findall(pattern, response_str)
            box = box[0][1:-1].split(",")
            for i in range(len(box)):
                box[i] = int(box[i])
            cus_mask = np.zeros((height, width))
            cus_mask[box[1]: box[1]+box[3], box[0]: box[0]+box[2]]=255
            mask = cus_mask
        except:
            raise gr.Error("Please set the mask manually, currently the VLM cannot output the mask!")

    elif category=="Background":
        labels = "background"
    elif category=="Global":
        mask = 255 * np.zeros((height, width))
    else:
        labels = object_wait_for_edit

    if mask is None:
        for thresh in [0.3,0.25,0.2,0.15,0.1,0.05,0]:
            try:
                detections = run_grounded_sam(
                    input_image={"image":Image.fromarray(image.astype('uint8')),
                                 "mask":None},
                    text_prompt=labels,
                    task_type="seg",
                    box_threshold=thresh,
                    text_threshold=0.25,
                    iou_threshold=0.5,
                    scribble_mode="split",
                    sam=sam,
                    sam_predictor=sam_predictor,
                    sam_automask_generator=sam_automask_generator,
                    groundingdino_model=groundingdino_model,
                    device=device,
                )
                mask = np.array(detections[0,0,...].cpu()) * 255
                break
            except:
                print(f"wrong in threshhold: {thresh}, continue")
                continue
    return mask


def vlm_response_prompt_after_apply_instruction(vlm_processor,
                                                vlm_model,
                                                image,
                                                editing_prompt,
                                                device):

    try:
        if isinstance(vlm_model, OpenAI):
            base64_image = encode_image(image)
            messages = create_apply_editing_messages_gpt4o(editing_prompt, base64_image)
            response_str = run_gpt4o_vl_inference(vlm_model, messages)
        elif isinstance(vlm_model, LlavaNextForConditionalGeneration):
            messages = create_apply_editing_messages_llava(editing_prompt)
            response_str = run_llava_next_inference(vlm_processor, vlm_model, messages, image, device)
        elif isinstance(vlm_model, Qwen2VLForConditionalGeneration):
            base64_image = encode_image(image)
            messages = create_apply_editing_messages_qwen2(editing_prompt, base64_image)
            response_str = run_qwen2_vl_inference(vlm_processor, vlm_model, messages, image, device)
        else:
            raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
    except Exception as e:
        raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
    return response_str
vlm_template.py
ADDED
@@ -0,0 +1,120 @@
import os
import sys
import torch
from openai import OpenAI
from transformers import (
    LlavaNextProcessor, LlavaNextForConditionalGeneration,
    Qwen2VLForConditionalGeneration, Qwen2VLProcessor
)
## init device
device = "cpu"
torch_dtype = torch.float16


vlms_list = [
    # {
    #     "type": "llava-next",
    #     "name": "llava-v1.6-mistral-7b-hf",
    #     "local_path": "models/vlms/llava-v1.6-mistral-7b-hf",
    #     "processor": LlavaNextProcessor.from_pretrained(
    #         "models/vlms/llava-v1.6-mistral-7b-hf"
    #     ) if os.path.exists("models/vlms/llava-v1.6-mistral-7b-hf") else LlavaNextProcessor.from_pretrained(
    #         "llava-hf/llava-v1.6-mistral-7b-hf"
    #     ),
    #     "model": LlavaNextForConditionalGeneration.from_pretrained(
    #         "models/vlms/llava-v1.6-mistral-7b-hf", torch_dtype=torch_dtype, device_map=device
    #     ).to("cpu") if os.path.exists("models/vlms/llava-v1.6-mistral-7b-hf") else
    #     LlavaNextForConditionalGeneration.from_pretrained(
    #         "llava-hf/llava-v1.6-mistral-7b-hf", torch_dtype=torch_dtype, device_map=device
    #     ).to("cpu"),
    # },
    {
        "type": "llava-next",
        "name": "llama3-llava-next-8b-hf (Preload)",
        "local_path": "models/vlms/llama3-llava-next-8b-hf",
        "processor": LlavaNextProcessor.from_pretrained(
            "models/vlms/llama3-llava-next-8b-hf"
        ) if os.path.exists("models/vlms/llama3-llava-next-8b-hf") else LlavaNextProcessor.from_pretrained(
            "llava-hf/llama3-llava-next-8b-hf"
        ),
        "model": LlavaNextForConditionalGeneration.from_pretrained(
            "models/vlms/llama3-llava-next-8b-hf", torch_dtype=torch_dtype, device_map=device
        ).to("cpu") if os.path.exists("models/vlms/llama3-llava-next-8b-hf") else
        LlavaNextForConditionalGeneration.from_pretrained(
            "llava-hf/llama3-llava-next-8b-hf", torch_dtype=torch_dtype, device_map=device
        ).to("cpu"),
    },
    # {
    #     "type": "llava-next",
    #     "name": "llava-v1.6-vicuna-13b-hf",
    #     "local_path": "models/vlms/llava-v1.6-vicuna-13b-hf",
    #     "processor": LlavaNextProcessor.from_pretrained(
    #         "models/vlms/llava-v1.6-vicuna-13b-hf"
    #     ) if os.path.exists("models/vlms/llava-v1.6-vicuna-13b-hf") else LlavaNextProcessor.from_pretrained(
    #         "llava-hf/llava-v1.6-vicuna-13b-hf"
    #     ),
    #     "model": LlavaNextForConditionalGeneration.from_pretrained(
    #         "models/vlms/llava-v1.6-vicuna-13b-hf", torch_dtype=torch_dtype, device_map=device
    #     ).to("cpu") if os.path.exists("models/vlms/llava-v1.6-vicuna-13b-hf") else
    #     LlavaNextForConditionalGeneration.from_pretrained(
    #         "llava-hf/llava-v1.6-vicuna-13b-hf", torch_dtype=torch_dtype, device_map=device
    #     ).to("cpu"),
    # },
    # {
    #     "type": "llava-next",
    #     "name": "llava-v1.6-34b-hf",
    #     "local_path": "models/vlms/llava-v1.6-34b-hf",
    #     "processor": LlavaNextProcessor.from_pretrained(
    #         "models/vlms/llava-v1.6-34b-hf"
    #     ) if os.path.exists("models/vlms/llava-v1.6-34b-hf") else LlavaNextProcessor.from_pretrained(
    #         "llava-hf/llava-v1.6-34b-hf"
    #     ),
    #     "model": LlavaNextForConditionalGeneration.from_pretrained(
    #         "models/vlms/llava-v1.6-34b-hf", torch_dtype=torch_dtype, device_map=device
    #     ).to("cpu") if os.path.exists("models/vlms/llava-v1.6-34b-hf") else
    #     LlavaNextForConditionalGeneration.from_pretrained(
    #         "llava-hf/llava-v1.6-34b-hf", torch_dtype=torch_dtype, device_map=device
    #     ).to("cpu"),
    # },
    # {
    #     "type": "qwen2-vl",
    #     "name": "Qwen2-VL-2B-Instruct",
    #     "local_path": "models/vlms/Qwen2-VL-2B-Instruct",
    #     "processor": Qwen2VLProcessor.from_pretrained(
    #         "models/vlms/Qwen2-VL-2B-Instruct"
    #     ) if os.path.exists("models/vlms/Qwen2-VL-2B-Instruct") else Qwen2VLProcessor.from_pretrained(
    #         "Qwen/Qwen2-VL-2B-Instruct"
    #     ),
    #     "model": Qwen2VLForConditionalGeneration.from_pretrained(
    #         "models/vlms/Qwen2-VL-2B-Instruct", torch_dtype=torch_dtype, device_map=device
    #     ).to("cpu") if os.path.exists("models/vlms/Qwen2-VL-2B-Instruct") else
    #     Qwen2VLForConditionalGeneration.from_pretrained(
    #         "Qwen/Qwen2-VL-2B-Instruct", torch_dtype=torch_dtype, device_map=device
    #     ).to("cpu"),
    # },
    {
        "type": "qwen2-vl",
        "name": "Qwen2-VL-7B-Instruct (Default)",
        "local_path": "models/vlms/Qwen2-VL-7B-Instruct",
        "processor": Qwen2VLProcessor.from_pretrained(
            "models/vlms/Qwen2-VL-7B-Instruct"
        ) if os.path.exists("models/vlms/Qwen2-VL-7B-Instruct") else Qwen2VLProcessor.from_pretrained(
            "Qwen/Qwen2-VL-7B-Instruct"
        ),
        "model": Qwen2VLForConditionalGeneration.from_pretrained(
            "models/vlms/Qwen2-VL-7B-Instruct", torch_dtype=torch_dtype, device_map=device
        ).to("cpu") if os.path.exists("models/vlms/Qwen2-VL-7B-Instruct") else
        Qwen2VLForConditionalGeneration.from_pretrained(
            "Qwen/Qwen2-VL-7B-Instruct", torch_dtype=torch_dtype, device_map=device
        ).to("cpu"),
    },
    {
        "type": "openai",
        "name": "GPT4-o (Highly Recommended)",
        "local_path": "",
        "processor": "",
        "model": ""
    },
]

vlms_template = {k["name"]: (k["type"], k["local_path"], k["processor"], k["model"]) for k in vlms_list}
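vlms_template maps each display name to a (type, local_path, processor, model) tuple, so the UI can resolve a dropdown selection into loaded objects. Below is a minimal lookup sketch, assuming the default Qwen2-VL entry above is available and a CUDA device exists; it is illustrative only and not part of the uploaded files.

# Illustrative lookup only; the key must match a "name" field defined in vlms_list.
vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template["Qwen2-VL-7B-Instruct (Default)"]
if vlm_type != "openai":
    vlm_model = vlm_model.to("cuda")  # entries are preloaded on CPU; move to GPU before inference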