jingyangcarl committed
Commit ce4c1d3 · 1 Parent(s): 1b2ea86
Files changed (12)
  1. README.md +12 -1
  2. app.py +55 -139
  3. app_canny.py +83 -0
  4. app_matnet.py +83 -0
  5. app_texnet.py +83 -0
  6. cv_utils.py +17 -0
  7. depth_estimator.py +25 -0
  8. image_segmentor.py +33 -0
  9. model.py +670 -0
  10. preprocessor.py +88 -0
  11. settings.py +19 -0
  12. utils.py +9 -0
README.md CHANGED
@@ -10,4 +10,15 @@ pinned: false
  license: mit
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+
+ ## setup locally
+ conda create -n matgen python=3.11
+ conda activate matgen
+ pip install diffusers["torch"] transformers accelerate xformers
+ pip install gradio
+ pip install controlnet-aux
+
+ ## local authen
+ huggingface-cli login
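A quick way to confirm that the environment described above is usable is a short import check; the package names and the CUDA query are the only assumptions here, and this is purely an illustrative sketch, not part of the commit:

```python
# sanity_check.py - minimal check of the "matgen" environment (illustrative only)
import torch
import diffusers
import transformers
import gradio

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("diffusers:", diffusers.__version__)
print("transformers:", transformers.__version__)
print("gradio:", gradio.__version__)
```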
app.py CHANGED
@@ -1,156 +1,72 @@
- import gradio as gr
- import numpy as np
- import random

- # import spaces #[uncomment to use ZeroGPU]
- from diffusers import DiffusionPipeline
  import torch

- device = "cuda" if torch.cuda.is_available() else "cpu"
- model_repo_id = "stabilityai/sdxl-turbo"  # Replace to the model you would like to use
-
- if torch.cuda.is_available():
-     torch_dtype = torch.float16
- else:
-     torch_dtype = torch.float32
-
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
- pipe = pipe.to(device)
-
- MAX_SEED = np.iinfo(np.int32).max
- MAX_IMAGE_SIZE = 1024
-
-
- # @spaces.GPU #[uncomment to use ZeroGPU]
- def infer(
-     prompt,
-     negative_prompt,
-     seed,
-     randomize_seed,
-     width,
-     height,
-     guidance_scale,
-     num_inference_steps,
-     progress=gr.Progress(track_tqdm=True),
- ):
-     if randomize_seed:
-         seed = random.randint(0, MAX_SEED)

-     generator = torch.Generator().manual_seed(seed)

-     image = pipe(
-         prompt=prompt,
-         negative_prompt=negative_prompt,
-         guidance_scale=guidance_scale,
-         num_inference_steps=num_inference_steps,
-         width=width,
-         height=height,
-         generator=generator,
-     ).images[0]

-     return image, seed


- examples = [
-     "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
-     "An astronaut riding a green horse",
-     "A delicious ceviche cheesecake slice",
- ]
-
- css = """
- #col-container {
-     margin: 0 auto;
-     max-width: 640px;
- }
- """

- with gr.Blocks(css=css) as demo:
-     with gr.Column(elem_id="col-container"):
-         gr.Markdown(" # Text-to-Image Gradio Template")

          with gr.Row():
-             prompt = gr.Text(
-                 label="Prompt",
-                 show_label=False,
-                 max_lines=1,
-                 placeholder="Enter your prompt",
-                 container=False,
-             )
-
-             run_button = gr.Button("Run", scale=0, variant="primary")
-
-         result = gr.Image(label="Result", show_label=False)
-
-         with gr.Accordion("Advanced Settings", open=False):
-             negative_prompt = gr.Text(
-                 label="Negative prompt",
-                 max_lines=1,
-                 placeholder="Enter a negative prompt",
-                 visible=False,
-             )
-
-             seed = gr.Slider(
-                 label="Seed",
-                 minimum=0,
-                 maximum=MAX_SEED,
-                 step=1,
-                 value=0,
-             )
-
-             randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-
-             with gr.Row():
-                 width = gr.Slider(
-                     label="Width",
-                     minimum=256,
-                     maximum=MAX_IMAGE_SIZE,
-                     step=32,
-                     value=1024,  # Replace with defaults that work for your model
-                 )
-
-                 height = gr.Slider(
-                     label="Height",
-                     minimum=256,
-                     maximum=MAX_IMAGE_SIZE,
-                     step=32,
-                     value=1024,  # Replace with defaults that work for your model
-                 )
-
-             with gr.Row():
-                 guidance_scale = gr.Slider(
-                     label="Guidance scale",
-                     minimum=0.0,
-                     maximum=10.0,
-                     step=0.1,
-                     value=0.0,  # Replace with defaults that work for your model
-                 )
-
-                 num_inference_steps = gr.Slider(
-                     label="Number of inference steps",
-                     minimum=1,
-                     maximum=50,
-                     step=1,
-                     value=2,  # Replace with defaults that work for your model
                  )

-         gr.Examples(examples=examples, inputs=[prompt])
      gr.on(
-         triggers=[run_button.click, prompt.submit],
-         fn=infer,
-         inputs=[
-             prompt,
-             negative_prompt,
-             seed,
-             randomize_seed,
-             width,
-             height,
-             guidance_scale,
-             num_inference_steps,
-         ],
-         outputs=[result, seed],
      )

  if __name__ == "__main__":
-     demo.launch()
-
-     print()
+ #!/usr/bin/env python

+ import gradio as gr
  import torch

+ from app_canny import create_demo as create_demo_canny

+ from model import Model
+ from settings import ALLOW_CHANGING_BASE_MODEL, DEFAULT_MODEL_ID, SHOW_DUPLICATE_BUTTON

+ DESCRIPTION = "# Material Authoring Demo v0.1"

+ if not torch.cuda.is_available():
+     DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"

+ model = Model(base_model_id=DEFAULT_MODEL_ID, task_name="Canny")

+ with gr.Blocks() as demo:
+     gr.Markdown(DESCRIPTION)
+     gr.DuplicateButton(
+         value="Duplicate Space for private use",
+         elem_id="duplicate-button",
+         visible=SHOW_DUPLICATE_BUTTON,
+     )

+     with gr.Tabs():
+         with gr.Tab("Canny"):
+             create_demo_canny(model.process_canny)
+         with gr.Tab("Texnet"):
+             create_demo_canny(model.process_canny)
+         with gr.Tab("Matnet"):
+             create_demo_canny(model.process_canny)

+     with gr.Accordion(label="Base model", open=False):
          with gr.Row():
+             with gr.Column(scale=5):
+                 current_base_model = gr.Text(label="Current base model")
+             with gr.Column(scale=1):
+                 check_base_model_button = gr.Button("Check current base model")
+         with gr.Row():
+             with gr.Column(scale=5):
+                 new_base_model_id = gr.Text(
+                     label="New base model",
+                     max_lines=1,
+                     placeholder="stable-diffusion-v1-5/stable-diffusion-v1-5",
+                     info="The base model must be compatible with Stable Diffusion v1.5.",
+                     interactive=ALLOW_CHANGING_BASE_MODEL,
                  )
+             with gr.Column(scale=1):
+                 change_base_model_button = gr.Button("Change base model", interactive=ALLOW_CHANGING_BASE_MODEL)
+         if not ALLOW_CHANGING_BASE_MODEL:
+             gr.Markdown(
+                 """The base model is not allowed to be changed in this Space so as not to slow down the demo, but it can be changed if you duplicate the Space."""
+             )

+     check_base_model_button.click(
+         fn=lambda: model.base_model_id,
+         outputs=current_base_model,
+         queue=False,
+         api_name="check_base_model",
+     )
      gr.on(
+         triggers=[new_base_model_id.submit, change_base_model_button.click],
+         fn=model.set_base_model,
+         inputs=new_base_model_id,
+         outputs=current_base_model,
+         api_name=False,
+         concurrency_id="main",
      )

  if __name__ == "__main__":
+     demo.queue(max_size=20).launch()
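Note that in this commit all three tabs are wired to `create_demo_canny(model.process_canny)`. The `app_texnet.py` and `app_matnet.py` modules added below each expose their own `create_demo`, so a later revision could plug them in along the following lines; this is only a sketch, and the per-task processors are assumptions (the commit itself still uses the Canny processor everywhere):

```python
# Hypothetical wiring once Texnet/Matnet get dedicated processors (illustrative only).
import gradio as gr

from app_canny import create_demo as create_demo_canny
from app_matnet import create_demo as create_demo_matnet
from app_texnet import create_demo as create_demo_texnet
from model import Model
from settings import DEFAULT_MODEL_ID

model = Model(base_model_id=DEFAULT_MODEL_ID, task_name="Canny")

with gr.Blocks() as demo:
    with gr.Tabs():
        with gr.Tab("Canny"):
            create_demo_canny(model.process_canny)
        with gr.Tab("Texnet"):
            create_demo_texnet(model.process_canny)  # swap in a texnet processor when available
        with gr.Tab("Matnet"):
            create_demo_matnet(model.process_canny)  # swap in a matnet processor when available
```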
app_canny.py ADDED
@@ -0,0 +1,83 @@
+ #!/usr/bin/env python
+
+ import gradio as gr
+
+ from settings import (
+     DEFAULT_IMAGE_RESOLUTION,
+     DEFAULT_NUM_IMAGES,
+     MAX_IMAGE_RESOLUTION,
+     MAX_NUM_IMAGES,
+     MAX_SEED,
+ )
+ from utils import randomize_seed_fn
+
+
+ def create_demo(process):
+     with gr.Blocks() as demo:
+         with gr.Row():
+             with gr.Column():
+                 image = gr.Image()
+                 prompt = gr.Textbox(label="Prompt", submit_btn=True)
+                 with gr.Accordion("Advanced options", open=False):
+                     num_samples = gr.Slider(
+                         label="Number of images", minimum=1, maximum=MAX_NUM_IMAGES, value=DEFAULT_NUM_IMAGES, step=1
+                     )
+                     image_resolution = gr.Slider(
+                         label="Image resolution",
+                         minimum=256,
+                         maximum=MAX_IMAGE_RESOLUTION,
+                         value=DEFAULT_IMAGE_RESOLUTION,
+                         step=256,
+                     )
+                     canny_low_threshold = gr.Slider(
+                         label="Canny low threshold", minimum=1, maximum=255, value=100, step=1
+                     )
+                     canny_high_threshold = gr.Slider(
+                         label="Canny high threshold", minimum=1, maximum=255, value=200, step=1
+                     )
+                     num_steps = gr.Slider(label="Number of steps", minimum=1, maximum=100, value=20, step=1)
+                     guidance_scale = gr.Slider(label="Guidance scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
+                     seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
+                     randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+                     a_prompt = gr.Textbox(label="Additional prompt", value="best quality, extremely detailed")
+                     n_prompt = gr.Textbox(
+                         label="Negative prompt",
+                         value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
+                     )
+             with gr.Column():
+                 result = gr.Gallery(label="Output", show_label=False, columns=2, object_fit="scale-down")
+         inputs = [
+             image,
+             prompt,
+             a_prompt,
+             n_prompt,
+             num_samples,
+             image_resolution,
+             num_steps,
+             guidance_scale,
+             seed,
+             canny_low_threshold,
+             canny_high_threshold,
+         ]
+         prompt.submit(
+             fn=randomize_seed_fn,
+             inputs=[seed, randomize_seed],
+             outputs=seed,
+             queue=False,
+             api_name=False,
+         ).then(
+             fn=process,
+             inputs=inputs,
+             outputs=result,
+             api_name="canny",
+             concurrency_id="main",
+         )
+     return demo
+
+
+ if __name__ == "__main__":
+     from model import Model
+
+     model = Model(task_name="Canny")
+     demo = create_demo(model.process_canny)
+     demo.queue().launch()
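The `process` callback handed to `create_demo` must accept the eleven inputs in the order of the `inputs` list above and return a list of images for the gallery. A minimal stand-in for exercising the UI without loading any model, purely illustrative and not part of the commit, could look like this:

```python
# Illustrative dummy processor: mirrors the argument order of the `inputs` list above.
import numpy as np
import PIL.Image

from app_canny import create_demo


def dummy_process(image, prompt, a_prompt, n_prompt, num_samples, image_resolution,
                  num_steps, guidance_scale, seed, low_threshold, high_threshold):
    # Return black placeholder images sized to the requested resolution.
    size = int(image_resolution)
    blank = PIL.Image.fromarray(np.zeros((size, size, 3), dtype=np.uint8))
    return [blank] * int(num_samples)


if __name__ == "__main__":
    create_demo(dummy_process).queue().launch()
```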
app_matnet.py ADDED
@@ -0,0 +1,83 @@
+ #!/usr/bin/env python
+
+ import gradio as gr
+
+ from settings import (
+     DEFAULT_IMAGE_RESOLUTION,
+     DEFAULT_NUM_IMAGES,
+     MAX_IMAGE_RESOLUTION,
+     MAX_NUM_IMAGES,
+     MAX_SEED,
+ )
+ from utils import randomize_seed_fn
+
+
+ def create_demo(process):
+     with gr.Blocks() as demo:
+         with gr.Row():
+             with gr.Column():
+                 image = gr.Image()
+                 prompt = gr.Textbox(label="Prompt", submit_btn=True)
+                 with gr.Accordion("Advanced options", open=False):
+                     num_samples = gr.Slider(
+                         label="Number of images", minimum=1, maximum=MAX_NUM_IMAGES, value=DEFAULT_NUM_IMAGES, step=1
+                     )
+                     image_resolution = gr.Slider(
+                         label="Image resolution",
+                         minimum=256,
+                         maximum=MAX_IMAGE_RESOLUTION,
+                         value=DEFAULT_IMAGE_RESOLUTION,
+                         step=256,
+                     )
+                     canny_low_threshold = gr.Slider(
+                         label="Canny low threshold", minimum=1, maximum=255, value=100, step=1
+                     )
+                     canny_high_threshold = gr.Slider(
+                         label="Canny high threshold", minimum=1, maximum=255, value=200, step=1
+                     )
+                     num_steps = gr.Slider(label="Number of steps", minimum=1, maximum=100, value=20, step=1)
+                     guidance_scale = gr.Slider(label="Guidance scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
+                     seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
+                     randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+                     a_prompt = gr.Textbox(label="Additional prompt", value="best quality, extremely detailed")
+                     n_prompt = gr.Textbox(
+                         label="Negative prompt",
+                         value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
+                     )
+             with gr.Column():
+                 result = gr.Gallery(label="Output", show_label=False, columns=2, object_fit="scale-down")
+         inputs = [
+             image,
+             prompt,
+             a_prompt,
+             n_prompt,
+             num_samples,
+             image_resolution,
+             num_steps,
+             guidance_scale,
+             seed,
+             canny_low_threshold,
+             canny_high_threshold,
+         ]
+         prompt.submit(
+             fn=randomize_seed_fn,
+             inputs=[seed, randomize_seed],
+             outputs=seed,
+             queue=False,
+             api_name=False,
+         ).then(
+             fn=process,
+             inputs=inputs,
+             outputs=result,
+             api_name="canny",
+             concurrency_id="main",
+         )
+     return demo
+
+
+ if __name__ == "__main__":
+     from model import Model
+
+     model = Model(task_name="Canny")
+     demo = create_demo(model.process_canny)
+     demo.queue().launch()
app_texnet.py ADDED
@@ -0,0 +1,83 @@
+ #!/usr/bin/env python
+
+ import gradio as gr
+
+ from settings import (
+     DEFAULT_IMAGE_RESOLUTION,
+     DEFAULT_NUM_IMAGES,
+     MAX_IMAGE_RESOLUTION,
+     MAX_NUM_IMAGES,
+     MAX_SEED,
+ )
+ from utils import randomize_seed_fn
+
+
+ def create_demo(process):
+     with gr.Blocks() as demo:
+         with gr.Row():
+             with gr.Column():
+                 image = gr.Image()
+                 prompt = gr.Textbox(label="Prompt", submit_btn=True)
+                 with gr.Accordion("Advanced options", open=False):
+                     num_samples = gr.Slider(
+                         label="Number of images", minimum=1, maximum=MAX_NUM_IMAGES, value=DEFAULT_NUM_IMAGES, step=1
+                     )
+                     image_resolution = gr.Slider(
+                         label="Image resolution",
+                         minimum=256,
+                         maximum=MAX_IMAGE_RESOLUTION,
+                         value=DEFAULT_IMAGE_RESOLUTION,
+                         step=256,
+                     )
+                     canny_low_threshold = gr.Slider(
+                         label="Canny low threshold", minimum=1, maximum=255, value=100, step=1
+                     )
+                     canny_high_threshold = gr.Slider(
+                         label="Canny high threshold", minimum=1, maximum=255, value=200, step=1
+                     )
+                     num_steps = gr.Slider(label="Number of steps", minimum=1, maximum=100, value=20, step=1)
+                     guidance_scale = gr.Slider(label="Guidance scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
+                     seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
+                     randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+                     a_prompt = gr.Textbox(label="Additional prompt", value="best quality, extremely detailed")
+                     n_prompt = gr.Textbox(
+                         label="Negative prompt",
+                         value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
+                     )
+             with gr.Column():
+                 result = gr.Gallery(label="Output", show_label=False, columns=2, object_fit="scale-down")
+         inputs = [
+             image,
+             prompt,
+             a_prompt,
+             n_prompt,
+             num_samples,
+             image_resolution,
+             num_steps,
+             guidance_scale,
+             seed,
+             canny_low_threshold,
+             canny_high_threshold,
+         ]
+         prompt.submit(
+             fn=randomize_seed_fn,
+             inputs=[seed, randomize_seed],
+             outputs=seed,
+             queue=False,
+             api_name=False,
+         ).then(
+             fn=process,
+             inputs=inputs,
+             outputs=result,
+             api_name="canny",
+             concurrency_id="main",
+         )
+     return demo
+
+
+ if __name__ == "__main__":
+     from model import Model
+
+     model = Model(task_name="Canny")
+     demo = create_demo(model.process_canny)
+     demo.queue().launch()
cv_utils.py ADDED
@@ -0,0 +1,17 @@
+ import cv2
+ import numpy as np
+
+
+ def resize_image(input_image, resolution, interpolation=None):
+     H, W, C = input_image.shape
+     H = float(H)
+     W = float(W)
+     k = float(resolution) / max(H, W)
+     H *= k
+     W *= k
+     H = int(np.round(H / 64.0)) * 64
+     W = int(np.round(W / 64.0)) * 64
+     if interpolation is None:
+         interpolation = cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA
+     img = cv2.resize(input_image, (W, H), interpolation=interpolation)
+     return img
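`resize_image` scales the long side of the input to `resolution` and then snaps both sides to multiples of 64, so the output is not always exactly the requested size. A small usage sketch (array shapes are arbitrary examples):

```python
# Usage sketch for resize_image: long side -> 512, both sides rounded to multiples of 64.
import numpy as np

from cv_utils import resize_image

img = np.zeros((768, 1024, 3), dtype=np.uint8)  # H, W, C
out = resize_image(img, resolution=512)
print(out.shape)  # (384, 512, 3)
```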
depth_estimator.py ADDED
@@ -0,0 +1,25 @@
+ import numpy as np
+ import PIL.Image
+ from controlnet_aux.util import HWC3
+ from transformers import pipeline
+
+ from cv_utils import resize_image
+
+
+ class DepthEstimator:
+     def __init__(self):
+         self.model = pipeline("depth-estimation")
+
+     def __call__(self, image: np.ndarray, **kwargs) -> PIL.Image.Image:
+         detect_resolution = kwargs.pop("detect_resolution", 512)
+         image_resolution = kwargs.pop("image_resolution", 512)
+         image = np.array(image)
+         image = HWC3(image)
+         image = resize_image(image, resolution=detect_resolution)
+         image = PIL.Image.fromarray(image)
+         image = self.model(image)
+         image = image["depth"]
+         image = np.array(image)
+         image = HWC3(image)
+         image = resize_image(image, resolution=image_resolution)
+         return PIL.Image.fromarray(image)
image_segmentor.py ADDED
@@ -0,0 +1,33 @@
+ import cv2
+ import numpy as np
+ import PIL.Image
+ import torch
+ from controlnet_aux.util import HWC3, ade_palette
+ from transformers import AutoImageProcessor, UperNetForSemanticSegmentation
+
+ from cv_utils import resize_image
+
+
+ class ImageSegmentor:
+     def __init__(self):
+         self.image_processor = AutoImageProcessor.from_pretrained("openmmlab/upernet-convnext-small")
+         self.image_segmentor = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-convnext-small")
+
+     @torch.inference_mode()
+     def __call__(self, image: np.ndarray, **kwargs) -> PIL.Image.Image:
+         detect_resolution = kwargs.pop("detect_resolution", 512)
+         image_resolution = kwargs.pop("image_resolution", 512)
+         image = HWC3(image)
+         image = resize_image(image, resolution=detect_resolution)
+         image = PIL.Image.fromarray(image)
+
+         pixel_values = self.image_processor(image, return_tensors="pt").pixel_values
+         outputs = self.image_segmentor(pixel_values)
+         seg = self.image_processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
+         color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
+         for label, color in enumerate(ade_palette()):
+             color_seg[seg == label, :] = color
+         color_seg = color_seg.astype(np.uint8)
+
+         color_seg = resize_image(color_seg, resolution=image_resolution, interpolation=cv2.INTER_NEAREST)
+         return PIL.Image.fromarray(color_seg)
model.py ADDED
@@ -0,0 +1,670 @@
+ import gc
+
+ import numpy as np
+ import PIL.Image
+ import torch
+ from controlnet_aux.util import HWC3
+ from diffusers import (
+     ControlNetModel,
+     DiffusionPipeline,
+     StableDiffusionControlNetPipeline,
+     UniPCMultistepScheduler,
+ )
+
+ from cv_utils import resize_image
+ from preprocessor import Preprocessor
+ from settings import MAX_IMAGE_RESOLUTION, MAX_NUM_IMAGES
+
+ CONTROLNET_MODEL_IDS = {
+     "Openpose": "lllyasviel/control_v11p_sd15_openpose",
+     "Canny": "lllyasviel/control_v11p_sd15_canny",
+     "MLSD": "lllyasviel/control_v11p_sd15_mlsd",
+     "scribble": "lllyasviel/control_v11p_sd15_scribble",
+     "softedge": "lllyasviel/control_v11p_sd15_softedge",
+     "segmentation": "lllyasviel/control_v11p_sd15_seg",
+     "depth": "lllyasviel/control_v11f1p_sd15_depth",
+     "NormalBae": "lllyasviel/control_v11p_sd15_normalbae",
+     "lineart": "lllyasviel/control_v11p_sd15_lineart",
+     "lineart_anime": "lllyasviel/control_v11p_sd15s2_lineart_anime",
+     "shuffle": "lllyasviel/control_v11e_sd15_shuffle",
+     "ip2p": "lllyasviel/control_v11e_sd15_ip2p",
+     "inpaint": "lllyasviel/control_v11e_sd15_inpaint",
+ }
+
+
+ def download_all_controlnet_weights() -> None:
+     for model_id in CONTROLNET_MODEL_IDS.values():
+         ControlNetModel.from_pretrained(model_id)
+
+
+ class Model:
+     def __init__(
+         self, base_model_id: str = "stable-diffusion-v1-5/stable-diffusion-v1-5", task_name: str = "Canny"
+     ) -> None:
+         self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+         self.base_model_id = ""
+         self.task_name = ""
+         self.pipe = self.load_pipe(base_model_id, task_name)
+         self.preprocessor = Preprocessor()
+
+     def load_pipe(self, base_model_id: str, task_name: str) -> DiffusionPipeline:
+         if (
+             base_model_id == self.base_model_id
+             and task_name == self.task_name
+             and hasattr(self, "pipe")
+             and self.pipe is not None
+         ):
+             return self.pipe
+         model_id = CONTROLNET_MODEL_IDS[task_name]
+         controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype=torch.float16)
+         pipe = StableDiffusionControlNetPipeline.from_pretrained(
+             base_model_id, safety_checker=None, controlnet=controlnet, torch_dtype=torch.float16
+         )
+         pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+         if self.device.type == "cuda":
+             pipe.enable_xformers_memory_efficient_attention()
+         pipe.to(self.device)
+         torch.cuda.empty_cache()
+         gc.collect()
+         self.base_model_id = base_model_id
+         self.task_name = task_name
+         return pipe
+
+     def set_base_model(self, base_model_id: str) -> str:
+         if not base_model_id or base_model_id == self.base_model_id:
+             return self.base_model_id
+         del self.pipe
+         torch.cuda.empty_cache()
+         gc.collect()
+         try:
+             self.pipe = self.load_pipe(base_model_id, self.task_name)
+         except Exception:  # noqa: BLE001
+             self.pipe = self.load_pipe(self.base_model_id, self.task_name)
+         return self.base_model_id
+
+     def load_controlnet_weight(self, task_name: str) -> None:
+         if task_name == self.task_name:
+             return
+         if self.pipe is not None and hasattr(self.pipe, "controlnet"):
+             del self.pipe.controlnet
+         torch.cuda.empty_cache()
+         gc.collect()
+         model_id = CONTROLNET_MODEL_IDS[task_name]
+         controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype=torch.float16)
+         controlnet.to(self.device)
+         torch.cuda.empty_cache()
+         gc.collect()
+         self.pipe.controlnet = controlnet
+         self.task_name = task_name
+
+     def get_prompt(self, prompt: str, additional_prompt: str) -> str:
+         return additional_prompt if not prompt else f"{prompt}, {additional_prompt}"
+
+     @torch.autocast("cuda")
+     def run_pipe(
+         self,
+         prompt: str,
+         negative_prompt: str,
+         control_image: PIL.Image.Image,
+         num_images: int,
+         num_steps: int,
+         guidance_scale: float,
+         seed: int,
+     ) -> list[PIL.Image.Image]:
+         generator = torch.Generator().manual_seed(seed)
+         return self.pipe(
+             prompt=prompt,
+             negative_prompt=negative_prompt,
+             guidance_scale=guidance_scale,
+             num_images_per_prompt=num_images,
+             num_inference_steps=num_steps,
+             generator=generator,
+             image=control_image,
+         ).images
+
+     @torch.inference_mode()
+     def process_canny(
+         self,
+         image: np.ndarray,
+         prompt: str,
+         additional_prompt: str,
+         negative_prompt: str,
+         num_images: int,
+         image_resolution: int,
+         num_steps: int,
+         guidance_scale: float,
+         seed: int,
+         low_threshold: int,
+         high_threshold: int,
+     ) -> list[PIL.Image.Image]:
+         if image is None:
+             raise ValueError
+         if image_resolution > MAX_IMAGE_RESOLUTION:
+             raise ValueError
+         if num_images > MAX_NUM_IMAGES:
+             raise ValueError
+
+         self.preprocessor.load("Canny")
+         control_image = self.preprocessor(
+             image=image, low_threshold=low_threshold, high_threshold=high_threshold, detect_resolution=image_resolution
+         )
+
+         self.load_controlnet_weight("Canny")
+         results = self.run_pipe(
+             prompt=self.get_prompt(prompt, additional_prompt),
+             negative_prompt=negative_prompt,
+             control_image=control_image,
+             num_images=num_images,
+             num_steps=num_steps,
+             guidance_scale=guidance_scale,
+             seed=seed,
+         )
+         return [control_image, *results]
+
+     @torch.inference_mode()
+     def process_mlsd(
+         self,
+         image: np.ndarray,
+         prompt: str,
+         additional_prompt: str,
+         negative_prompt: str,
+         num_images: int,
+         image_resolution: int,
+         preprocess_resolution: int,
+         num_steps: int,
+         guidance_scale: float,
+         seed: int,
+         value_threshold: float,
+         distance_threshold: float,
+     ) -> list[PIL.Image.Image]:
+         if image is None:
+             raise ValueError
+         if image_resolution > MAX_IMAGE_RESOLUTION:
+             raise ValueError
+         if num_images > MAX_NUM_IMAGES:
+             raise ValueError
+
+         self.preprocessor.load("MLSD")
+         control_image = self.preprocessor(
+             image=image,
+             image_resolution=image_resolution,
+             detect_resolution=preprocess_resolution,
+             thr_v=value_threshold,
+             thr_d=distance_threshold,
+         )
+         self.load_controlnet_weight("MLSD")
+         results = self.run_pipe(
+             prompt=self.get_prompt(prompt, additional_prompt),
+             negative_prompt=negative_prompt,
+             control_image=control_image,
+             num_images=num_images,
+             num_steps=num_steps,
+             guidance_scale=guidance_scale,
+             seed=seed,
+         )
+         return [control_image, *results]
+
+     @torch.inference_mode()
+     def process_scribble(
+         self,
+         image: np.ndarray,
+         prompt: str,
+         additional_prompt: str,
+         negative_prompt: str,
+         num_images: int,
+         image_resolution: int,
+         preprocess_resolution: int,
+         num_steps: int,
+         guidance_scale: float,
+         seed: int,
+         preprocessor_name: str,
+     ) -> list[PIL.Image.Image]:
+         if image is None:
+             raise ValueError
+         if image_resolution > MAX_IMAGE_RESOLUTION:
+             raise ValueError
+         if num_images > MAX_NUM_IMAGES:
+             raise ValueError
+
+         if preprocessor_name == "None":
+             image = HWC3(image)
+             image = resize_image(image, resolution=image_resolution)
+             control_image = PIL.Image.fromarray(image)
+         elif preprocessor_name == "HED":
+             self.preprocessor.load(preprocessor_name)
+             control_image = self.preprocessor(
+                 image=image,
+                 image_resolution=image_resolution,
+                 detect_resolution=preprocess_resolution,
+                 scribble=False,
+             )
+         elif preprocessor_name == "PidiNet":
+             self.preprocessor.load(preprocessor_name)
+             control_image = self.preprocessor(
+                 image=image,
+                 image_resolution=image_resolution,
+                 detect_resolution=preprocess_resolution,
+                 safe=False,
+             )
+         self.load_controlnet_weight("scribble")
+         results = self.run_pipe(
+             prompt=self.get_prompt(prompt, additional_prompt),
+             negative_prompt=negative_prompt,
+             control_image=control_image,
+             num_images=num_images,
+             num_steps=num_steps,
+             guidance_scale=guidance_scale,
+             seed=seed,
+         )
+         return [control_image, *results]
+
+     @torch.inference_mode()
+     def process_scribble_interactive(
+         self,
+         image_and_mask: dict[str, np.ndarray | list[np.ndarray]] | None,
+         prompt: str,
+         additional_prompt: str,
+         negative_prompt: str,
+         num_images: int,
+         image_resolution: int,
+         num_steps: int,
+         guidance_scale: float,
+         seed: int,
+     ) -> list[PIL.Image.Image]:
+         if image_and_mask is None:
+             raise ValueError
+         if image_resolution > MAX_IMAGE_RESOLUTION:
+             raise ValueError
+         if num_images > MAX_NUM_IMAGES:
+             raise ValueError
+
+         image = 255 - image_and_mask["composite"]  # type: ignore
+         image = HWC3(image)
+         image = resize_image(image, resolution=image_resolution)
+         control_image = PIL.Image.fromarray(image)
+
+         self.load_controlnet_weight("scribble")
+         results = self.run_pipe(
+             prompt=self.get_prompt(prompt, additional_prompt),
+             negative_prompt=negative_prompt,
+             control_image=control_image,
+             num_images=num_images,
+             num_steps=num_steps,
+             guidance_scale=guidance_scale,
+             seed=seed,
+         )
+         return [control_image, *results]
+
+     @torch.inference_mode()
+     def process_softedge(
+         self,
+         image: np.ndarray,
+         prompt: str,
+         additional_prompt: str,
+         negative_prompt: str,
+         num_images: int,
+         image_resolution: int,
+         preprocess_resolution: int,
+         num_steps: int,
+         guidance_scale: float,
+         seed: int,
+         preprocessor_name: str,
+     ) -> list[PIL.Image.Image]:
+         if image is None:
+             raise ValueError
+         if image_resolution > MAX_IMAGE_RESOLUTION:
+             raise ValueError
+         if num_images > MAX_NUM_IMAGES:
+             raise ValueError
+
+         if preprocessor_name == "None":
+             image = HWC3(image)
+             image = resize_image(image, resolution=image_resolution)
+             control_image = PIL.Image.fromarray(image)
+         elif preprocessor_name in ["HED", "HED safe"]:
+             safe = "safe" in preprocessor_name
+             self.preprocessor.load("HED")
+             control_image = self.preprocessor(
+                 image=image,
+                 image_resolution=image_resolution,
+                 detect_resolution=preprocess_resolution,
+                 scribble=safe,
+             )
+         elif preprocessor_name in ["PidiNet", "PidiNet safe"]:
+             safe = "safe" in preprocessor_name
+             self.preprocessor.load("PidiNet")
+             control_image = self.preprocessor(
+                 image=image,
+                 image_resolution=image_resolution,
+                 detect_resolution=preprocess_resolution,
+                 safe=safe,
+             )
+         else:
+             raise ValueError
+         self.load_controlnet_weight("softedge")
+         results = self.run_pipe(
+             prompt=self.get_prompt(prompt, additional_prompt),
+             negative_prompt=negative_prompt,
+             control_image=control_image,
+             num_images=num_images,
+             num_steps=num_steps,
+             guidance_scale=guidance_scale,
+             seed=seed,
+         )
+         return [control_image, *results]
+
+     @torch.inference_mode()
+     def process_openpose(
+         self,
+         image: np.ndarray,
+         prompt: str,
+         additional_prompt: str,
+         negative_prompt: str,
+         num_images: int,
+         image_resolution: int,
+         preprocess_resolution: int,
+         num_steps: int,
+         guidance_scale: float,
+         seed: int,
+         preprocessor_name: str,
+     ) -> list[PIL.Image.Image]:
+         if image is None:
+             raise ValueError
+         if image_resolution > MAX_IMAGE_RESOLUTION:
+             raise ValueError
+         if num_images > MAX_NUM_IMAGES:
+             raise ValueError
+
+         if preprocessor_name == "None":
+             image = HWC3(image)
+             image = resize_image(image, resolution=image_resolution)
+             control_image = PIL.Image.fromarray(image)
+         else:
+             self.preprocessor.load("Openpose")
+             control_image = self.preprocessor(
+                 image=image,
+                 image_resolution=image_resolution,
+                 detect_resolution=preprocess_resolution,
+                 hand_and_face=True,
+             )
+         self.load_controlnet_weight("Openpose")
+         results = self.run_pipe(
+             prompt=self.get_prompt(prompt, additional_prompt),
+             negative_prompt=negative_prompt,
+             control_image=control_image,
+             num_images=num_images,
+             num_steps=num_steps,
+             guidance_scale=guidance_scale,
+             seed=seed,
+         )
+         return [control_image, *results]
+
+     @torch.inference_mode()
+     def process_segmentation(
+         self,
+         image: np.ndarray,
+         prompt: str,
+         additional_prompt: str,
+         negative_prompt: str,
+         num_images: int,
+         image_resolution: int,
+         preprocess_resolution: int,
+         num_steps: int,
+         guidance_scale: float,
+         seed: int,
+         preprocessor_name: str,
+     ) -> list[PIL.Image.Image]:
+         if image is None:
+             raise ValueError
+         if image_resolution > MAX_IMAGE_RESOLUTION:
+             raise ValueError
+         if num_images > MAX_NUM_IMAGES:
+             raise ValueError
+
+         if preprocessor_name == "None":
+             image = HWC3(image)
+             image = resize_image(image, resolution=image_resolution)
+             control_image = PIL.Image.fromarray(image)
+         else:
+             self.preprocessor.load(preprocessor_name)
+             control_image = self.preprocessor(
+                 image=image,
+                 image_resolution=image_resolution,
+                 detect_resolution=preprocess_resolution,
+             )
+         self.load_controlnet_weight("segmentation")
+         results = self.run_pipe(
+             prompt=self.get_prompt(prompt, additional_prompt),
+             negative_prompt=negative_prompt,
+             control_image=control_image,
+             num_images=num_images,
+             num_steps=num_steps,
+             guidance_scale=guidance_scale,
+             seed=seed,
+         )
+         return [control_image, *results]
+
+     @torch.inference_mode()
+     def process_depth(
+         self,
+         image: np.ndarray,
+         prompt: str,
+         additional_prompt: str,
+         negative_prompt: str,
+         num_images: int,
+         image_resolution: int,
+         preprocess_resolution: int,
+         num_steps: int,
+         guidance_scale: float,
+         seed: int,
+         preprocessor_name: str,
+     ) -> list[PIL.Image.Image]:
+         if image is None:
+             raise ValueError
+         if image_resolution > MAX_IMAGE_RESOLUTION:
+             raise ValueError
+         if num_images > MAX_NUM_IMAGES:
+             raise ValueError
+
+         if preprocessor_name == "None":
+             image = HWC3(image)
+             image = resize_image(image, resolution=image_resolution)
+             control_image = PIL.Image.fromarray(image)
+         else:
+             self.preprocessor.load(preprocessor_name)
+             control_image = self.preprocessor(
+                 image=image,
+                 image_resolution=image_resolution,
+                 detect_resolution=preprocess_resolution,
+             )
+         self.load_controlnet_weight("depth")
+         results = self.run_pipe(
+             prompt=self.get_prompt(prompt, additional_prompt),
+             negative_prompt=negative_prompt,
+             control_image=control_image,
+             num_images=num_images,
+             num_steps=num_steps,
+             guidance_scale=guidance_scale,
+             seed=seed,
+         )
+         return [control_image, *results]
+
+     @torch.inference_mode()
+     def process_normal(
+         self,
+         image: np.ndarray,
+         prompt: str,
+         additional_prompt: str,
+         negative_prompt: str,
+         num_images: int,
+         image_resolution: int,
+         preprocess_resolution: int,
+         num_steps: int,
+         guidance_scale: float,
+         seed: int,
+         preprocessor_name: str,
+     ) -> list[PIL.Image.Image]:
+         if image is None:
+             raise ValueError
+         if image_resolution > MAX_IMAGE_RESOLUTION:
+             raise ValueError
+         if num_images > MAX_NUM_IMAGES:
+             raise ValueError
+
+         if preprocessor_name == "None":
+             image = HWC3(image)
+             image = resize_image(image, resolution=image_resolution)
+             control_image = PIL.Image.fromarray(image)
+         else:
+             self.preprocessor.load("NormalBae")
+             control_image = self.preprocessor(
+                 image=image,
+                 image_resolution=image_resolution,
+                 detect_resolution=preprocess_resolution,
+             )
+         self.load_controlnet_weight("NormalBae")
+         results = self.run_pipe(
+             prompt=self.get_prompt(prompt, additional_prompt),
+             negative_prompt=negative_prompt,
+             control_image=control_image,
+             num_images=num_images,
+             num_steps=num_steps,
+             guidance_scale=guidance_scale,
+             seed=seed,
+         )
+         return [control_image, *results]
+
+     @torch.inference_mode()
+     def process_lineart(
+         self,
+         image: np.ndarray,
+         prompt: str,
+         additional_prompt: str,
+         negative_prompt: str,
+         num_images: int,
+         image_resolution: int,
+         preprocess_resolution: int,
+         num_steps: int,
+         guidance_scale: float,
+         seed: int,
+         preprocessor_name: str,
+     ) -> list[PIL.Image.Image]:
+         if image is None:
+             raise ValueError
+         if image_resolution > MAX_IMAGE_RESOLUTION:
+             raise ValueError
+         if num_images > MAX_NUM_IMAGES:
+             raise ValueError
+
+         if preprocessor_name in ["None", "None (anime)"]:
+             image = HWC3(image)
+             image = resize_image(image, resolution=image_resolution)
+             control_image = PIL.Image.fromarray(image)
+         elif preprocessor_name in ["Lineart", "Lineart coarse"]:
+             coarse = "coarse" in preprocessor_name
+             self.preprocessor.load("Lineart")
+             control_image = self.preprocessor(
+                 image=image,
+                 image_resolution=image_resolution,
+                 detect_resolution=preprocess_resolution,
+                 coarse=coarse,
+             )
+         elif preprocessor_name == "Lineart (anime)":
+             self.preprocessor.load("LineartAnime")
+             control_image = self.preprocessor(
+                 image=image,
+                 image_resolution=image_resolution,
+                 detect_resolution=preprocess_resolution,
+             )
+         if "anime" in preprocessor_name:
+             self.load_controlnet_weight("lineart_anime")
+         else:
+             self.load_controlnet_weight("lineart")
+         results = self.run_pipe(
+             prompt=self.get_prompt(prompt, additional_prompt),
+             negative_prompt=negative_prompt,
+             control_image=control_image,
+             num_images=num_images,
+             num_steps=num_steps,
+             guidance_scale=guidance_scale,
+             seed=seed,
+         )
+         return [control_image, *results]
+
+     @torch.inference_mode()
+     def process_shuffle(
+         self,
+         image: np.ndarray,
+         prompt: str,
+         additional_prompt: str,
+         negative_prompt: str,
+         num_images: int,
+         image_resolution: int,
+         num_steps: int,
+         guidance_scale: float,
+         seed: int,
+         preprocessor_name: str,
+     ) -> list[PIL.Image.Image]:
+         if image is None:
+             raise ValueError
+         if image_resolution > MAX_IMAGE_RESOLUTION:
+             raise ValueError
+         if num_images > MAX_NUM_IMAGES:
+             raise ValueError
+
+         if preprocessor_name == "None":
+             image = HWC3(image)
+             image = resize_image(image, resolution=image_resolution)
+             control_image = PIL.Image.fromarray(image)
+         else:
+             self.preprocessor.load(preprocessor_name)
+             control_image = self.preprocessor(
+                 image=image,
+                 image_resolution=image_resolution,
+             )
+         self.load_controlnet_weight("shuffle")
+         results = self.run_pipe(
+             prompt=self.get_prompt(prompt, additional_prompt),
+             negative_prompt=negative_prompt,
+             control_image=control_image,
+             num_images=num_images,
+             num_steps=num_steps,
+             guidance_scale=guidance_scale,
+             seed=seed,
+         )
+         return [control_image, *results]
+
+     @torch.inference_mode()
+     def process_ip2p(
+         self,
+         image: np.ndarray,
+         prompt: str,
+         additional_prompt: str,
+         negative_prompt: str,
+         num_images: int,
+         image_resolution: int,
+         num_steps: int,
+         guidance_scale: float,
+         seed: int,
+     ) -> list[PIL.Image.Image]:
+         if image is None:
+             raise ValueError
+         if image_resolution > MAX_IMAGE_RESOLUTION:
+             raise ValueError
+         if num_images > MAX_NUM_IMAGES:
+             raise ValueError
+
+         image = HWC3(image)
+         image = resize_image(image, resolution=image_resolution)
+         control_image = PIL.Image.fromarray(image)
+         self.load_controlnet_weight("ip2p")
+         results = self.run_pipe(
+             prompt=self.get_prompt(prompt, additional_prompt),
+             negative_prompt=negative_prompt,
+             control_image=control_image,
+             num_images=num_images,
+             num_steps=num_steps,
+             guidance_scale=guidance_scale,
+             seed=seed,
+         )
+         return [control_image, *results]
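A minimal, non-authoritative usage sketch of the class above (assuming a CUDA GPU, the default Stable Diffusion 1.5 base model, and a hypothetical `input.png`); `process_canny` returns the Canny control image followed by the generated samples:

```python
# Illustrative use of Model.process_canny; "input.png" is a hypothetical file name.
import numpy as np
import PIL.Image

from model import Model

model = Model(task_name="Canny")
image = np.array(PIL.Image.open("input.png").convert("RGB"))
outputs = model.process_canny(
    image=image,
    prompt="a brick wall",
    additional_prompt="best quality, extremely detailed",
    negative_prompt="lowres, worst quality",
    num_images=1,
    image_resolution=512,
    num_steps=20,
    guidance_scale=9.0,
    seed=0,
    low_threshold=100,
    high_threshold=200,
)
outputs[0].save("canny.png")     # control image
outputs[1].save("result_0.png")  # first generated sample
```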
preprocessor.py ADDED
@@ -0,0 +1,88 @@
+ import gc
+ from typing import TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     from collections.abc import Callable
+
+ import numpy as np
+ import PIL.Image
+ import torch
+ from controlnet_aux import (
+     CannyDetector,
+     ContentShuffleDetector,
+     HEDdetector,
+     LineartAnimeDetector,
+     LineartDetector,
+     MidasDetector,
+     MLSDdetector,
+     NormalBaeDetector,
+     OpenposeDetector,
+     PidiNetDetector,
+ )
+ from controlnet_aux.util import HWC3
+
+ from cv_utils import resize_image
+ from depth_estimator import DepthEstimator
+ from image_segmentor import ImageSegmentor
+
+
+ class Preprocessor:
+     MODEL_ID = "lllyasviel/Annotators"
+
+     def __init__(self) -> None:
+         self.model: Callable = None  # type: ignore
+         self.name = ""
+
+     def load(self, name: str) -> None:  # noqa: C901, PLR0912
+         if name == self.name:
+             return
+         if name == "HED":
+             self.model = HEDdetector.from_pretrained(self.MODEL_ID)
+         elif name == "Midas":
+             self.model = MidasDetector.from_pretrained(self.MODEL_ID)
+         elif name == "MLSD":
+             self.model = MLSDdetector.from_pretrained(self.MODEL_ID)
+         elif name == "Openpose":
+             self.model = OpenposeDetector.from_pretrained(self.MODEL_ID)
+         elif name == "PidiNet":
+             self.model = PidiNetDetector.from_pretrained(self.MODEL_ID)
+         elif name == "NormalBae":
+             self.model = NormalBaeDetector.from_pretrained(self.MODEL_ID)
+         elif name == "Lineart":
+             self.model = LineartDetector.from_pretrained(self.MODEL_ID)
+         elif name == "LineartAnime":
+             self.model = LineartAnimeDetector.from_pretrained(self.MODEL_ID)
+         elif name == "Canny":
+             self.model = CannyDetector()
+         elif name == "ContentShuffle":
+             self.model = ContentShuffleDetector()
+         elif name == "DPT":
+             self.model = DepthEstimator()
+         elif name == "UPerNet":
+             self.model = ImageSegmentor()
+         else:
+             raise ValueError
+         torch.cuda.empty_cache()
+         gc.collect()
+         self.name = name
+
+     def __call__(self, image: PIL.Image.Image, **kwargs) -> PIL.Image.Image:  # noqa: ANN003
+         if self.name == "Canny":
+             if "detect_resolution" in kwargs:
+                 detect_resolution = kwargs.pop("detect_resolution")
+                 image = np.array(image)
+                 image = HWC3(image)
+                 image = resize_image(image, resolution=detect_resolution)
+             image = self.model(image, **kwargs)
+             return PIL.Image.fromarray(image)
+         if self.name == "Midas":
+             detect_resolution = kwargs.pop("detect_resolution", 512)
+             image_resolution = kwargs.pop("image_resolution", 512)
+             image = np.array(image)
+             image = HWC3(image)
+             image = resize_image(image, resolution=detect_resolution)
+             image = self.model(image, **kwargs)
+             image = HWC3(image)
+             image = resize_image(image, resolution=image_resolution)
+             return PIL.Image.fromarray(image)
+         return self.model(image, **kwargs)
settings.py ADDED
@@ -0,0 +1,19 @@
+ import os
+
+ import numpy as np
+
+ DEFAULT_MODEL_ID = os.getenv("DEFAULT_MODEL_ID", "stable-diffusion-v1-5/stable-diffusion-v1-5")
+
+ MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "3"))
+ DEFAULT_NUM_IMAGES = min(MAX_NUM_IMAGES, int(os.getenv("DEFAULT_NUM_IMAGES", "3")))
+ MAX_IMAGE_RESOLUTION = int(os.getenv("MAX_IMAGE_RESOLUTION", "768"))
+ DEFAULT_IMAGE_RESOLUTION = min(MAX_IMAGE_RESOLUTION, int(os.getenv("DEFAULT_IMAGE_RESOLUTION", "768")))
+
+ ALLOW_CHANGING_BASE_MODEL = os.getenv("SPACE_ID") != "hysts/ControlNet-v1-1"
+ SHOW_DUPLICATE_BUTTON = os.getenv("SHOW_DUPLICATE_BUTTON") == "1"
+
+ MAX_SEED = np.iinfo(np.int32).max
+
+ # setup CUDA
+ if os.getenv("CUDA_VISIBLE_DEVICES") is None:
+     os.environ["CUDA_VISIBLE_DEVICES"] = "7"
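All of these limits are read from environment variables at import time, and the final block pins the process to GPU index 7 unless `CUDA_VISIBLE_DEVICES` is already set. An illustrative override from a launcher script (the values are arbitrary examples, not defaults of the project):

```python
# Set limits before `settings` is imported, e.g. at the top of a launcher script.
import os

os.environ["MAX_NUM_IMAGES"] = "1"
os.environ["MAX_IMAGE_RESOLUTION"] = "512"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # avoid the hard-coded fallback of GPU 7

import settings  # noqa: E402  (import after the environment is configured)

print(settings.MAX_NUM_IMAGES, settings.MAX_IMAGE_RESOLUTION)
```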
utils.py ADDED
@@ -0,0 +1,9 @@
+ import random
+
+ from settings import MAX_SEED
+
+
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)  # noqa: S311
+     return seed
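For reference, `randomize_seed_fn` runs as the first step of each `prompt.submit` chain so the seed slider is refreshed before generation; a behaviour sketch:

```python
# With randomize_seed=False the incoming seed passes through unchanged;
# otherwise a fresh value in [0, MAX_SEED] is drawn.
from utils import randomize_seed_fn

print(randomize_seed_fn(seed=42, randomize_seed=False))  # 42
print(randomize_seed_fn(seed=42, randomize_seed=True))   # some random seed
```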