FrankFacundo committed on
Commit bcb79da · 1 Parent(s): 5ed0a9c

controlnet

Files changed (2)
  1. app.py +376 -4
  2. requirements.txt +8 -0
app.py CHANGED
@@ -1,7 +1,379 @@
+ import numpy as np
  import gradio as gr
-
- def greet(name):
-     return "Hello " + name + "!!"
-
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
+ import spaces
+ import os
+ import random
+
+ import subprocess
+ import torch
+ from PIL import Image
+ import cv2
+ from huggingface_hub import login
+ from diffusers import FluxControlNetPipeline, FluxControlNetModel
+ from diffusers.models import FluxMultiControlNetModel
+
+ import warnings
+ from typing import Tuple
+
+ """
+ FLUX.1 ControlNet demo
+ ----------------------
+ This script builds the Gradio interface with **one** control-image upload slot
+ and integrates the FLUX.1-dev-ControlNet-Union-Pro model.
+
+ Key points
+ ~~~~~~~~~~
+ * Single *control image* input (left).
+ * *Result* and *Pre-processed Cond* previews side by side (center & right).
+ * *Prompt* textbox plus a dedicated **ControlNet** panel for choosing the mode and strength.
+ * Seed handling with optional randomisation.
+ * Advanced sliders for *Guidance scale* and *Inference steps*.
+ * Works on CUDA (bfloat16) or CPU (float32).
+ * Minimal Canny preview implementation when the *canny* mode is selected
+   (extend as you like for the other modes).
+
+ Before running, set the `HF_TOKEN_NEW` environment variable **or** call
+ `login("<YOUR_HF_TOKEN>")` explicitly.
+ """
+
+ # Clear any stale ZeroGPU offload cache before loading the models
+ subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)
+
+ # --------------------------------------------------
+ # Model & pipeline setup
+ # --------------------------------------------------
+ HF_TOKEN = os.getenv("HF_TOKEN_NEW")
+ if HF_TOKEN:
+     login(HF_TOKEN)
+ # If you prefer to hard-code the token, uncomment:
+ # login("hf_your_token_here")
+
+ BASE_MODEL = "black-forest-labs/FLUX.1-dev"
+ CONTROLNET_MODEL = "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro"
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
+
+ print("Loading ControlNet weights...")
+ controlnet_single = FluxControlNetModel.from_pretrained(
+     CONTROLNET_MODEL, torch_dtype=dtype
+ )
+ controlnet = FluxMultiControlNetModel([controlnet_single])
+
+ print("Loading FLUX.1-dev base pipeline...")
+ pipe = FluxControlNetPipeline.from_pretrained(
+     BASE_MODEL, controlnet=controlnet, torch_dtype=dtype
+ ).to(device)
+ pipe.set_progress_bar_config(disable=True)
+ print("Pipeline ready.")
+
+ # --------------------------------------------------
+ # UI -> model value mapping
+ # --------------------------------------------------
+ MODE_MAPPING = {
+     "canny": 0,
+     "tile": 1,
+     "depth": 2,
+     "blur": 3,
+     "pose": 4,
+     "gray": 5,
+     "low quality": 6,
+ }
+
+ MAX_SEED = 100
+
+ # -----------------------------------------------------------------------------
+ # Preview helpers – one small, self-contained function per mode
+ # -----------------------------------------------------------------------------
+
+
+ def _preview_canny(
+     pil_img: Image.Image, canny_threshold_1: int, canny_threshold_2: int
+ ) -> Image.Image:
+     """Fast Canny-edge preview."""
+     arr = np.array(pil_img.convert("RGB"))
+     edges = cv2.Canny(arr, threshold1=canny_threshold_1, threshold2=canny_threshold_2)
+     edges_rgb = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)
+     return Image.fromarray(edges_rgb)
+
+
+ # ――― tile ―――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――― #
+
+
+ def _preview_tile(pil_img: Image.Image, grid: Tuple[int, int] = (2, 2)) -> Image.Image:
+     """Replicate *pil_img* into an *n×m* tiled grid (default 2×2).
+
+     This offers a quick visual hint of what the *tile* control mode will do
+     (repeatable textures, etc.)."""
+     cols, rows = grid
+     img_rgb = pil_img.convert("RGB")
+     w, h = img_rgb.size
+     tiled = Image.new("RGB", (w * cols, h * rows))
+     for c in range(cols):
+         for r in range(rows):
+             tiled.paste(img_rgb, (c * w, r * h))
+     return tiled
+
+
+ # ――― depth ――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――― #
+
+
+ def _preview_depth(pil_img: Image.Image) -> Image.Image:
+     """Very rough *depth* proxy using the Laplacian and a colormap.
+
+     ▸ Convert to gray
+     ▸ Run the Laplacian to highlight depth-like gradients
+     ▸ Apply a TURBO colormap to mimic a depth heat-map"""
+     gray = cv2.cvtColor(np.array(pil_img.convert("RGB")), cv2.COLOR_RGB2GRAY)
+     lap = cv2.Laplacian(gray, cv2.CV_16S, ksize=3)
+     depth = cv2.convertScaleAbs(lap)
+     # applyColorMap returns BGR, so convert before handing back to PIL (RGB)
+     depth_color = cv2.applyColorMap(depth, cv2.COLORMAP_TURBO)
+     return Image.fromarray(cv2.cvtColor(depth_color, cv2.COLOR_BGR2RGB))
+
+
+ # ――― blur ――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――― #
+
+
+ def _preview_blur(pil_img: Image.Image, ksize: int = 15) -> Image.Image:
+     """Gaussian-blur preview.
+
+     A single, relatively large kernel is enough for UI illustration."""
+     if ksize % 2 == 0:
+         ksize += 1  # Gaussian kernel size must be odd
+     blurred = cv2.GaussianBlur(np.array(pil_img.convert("RGB")), (ksize, ksize), sigmaX=0)
+     return Image.fromarray(blurred)
+
+
+ # ――― pose ―――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――― #
+
+
+ def _preview_pose(pil_img: Image.Image) -> Image.Image:
+     """Attempt a lightweight 2-D pose overlay using *mediapipe* if available.
+
+     If *mediapipe* is not installed (or CPU inference fails), we gracefully
+     fall back to an edge-map preview so the UI never crashes."""
+     try:
+         import mediapipe as mp  # type: ignore
+
+         mp_pose = mp.solutions.pose
+         mp_drawing = mp.solutions.drawing_utils
+
+         img_bgr = cv2.cvtColor(np.array(pil_img.convert("RGB")), cv2.COLOR_RGB2BGR)
+         with mp_pose.Pose(static_image_mode=True) as pose_estimator:
+             # MediaPipe expects RGB, so flip the channel order back
+             results = pose_estimator.process(img_bgr[..., ::-1])
+
+         annotated = img_bgr.copy()
+         if results.pose_landmarks:
+             mp_drawing.draw_landmarks(
+                 annotated, results.pose_landmarks, mp_pose.POSE_CONNECTIONS
+             )
+         annotated_rgb = cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)
+         return Image.fromarray(annotated_rgb)
+
+     except Exception as exc:  # pragma: no cover – any import / runtime error
+         warnings.warn(
+             f"Pose preview failed ({exc!s}); falling back to Canny.", RuntimeWarning
+         )
+         # Return an edge map as a sensible fallback rather than crashing the UI
+         return _preview_canny(pil_img, 100, 200)
+
+
+ # ――― gray ――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――― #
+
+
+ def _preview_gray(pil_img: Image.Image) -> Image.Image:
+     """Simple grayscale conversion, but keep a 3-channel RGB image so the UI
+     widget pipeline stays consistent."""
+     gray = cv2.cvtColor(np.array(pil_img.convert("RGB")), cv2.COLOR_RGB2GRAY)
+     gray_rgb = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
+     return Image.fromarray(gray_rgb)
+
+
+ # ――― low quality ――――――――――――――――――――――――――――――――――――――――――――――――――――――――― #
+
+
+ def _preview_low_quality(pil_img: Image.Image, factor: int = 8) -> Image.Image:
+     """Mimic a low-quality thumbnail: aggressively downsample, then upscale.
+
+     The default *factor* (8×) is chosen to make artefacts obvious."""
+     img_rgb = pil_img.convert("RGB")
+     w, h = img_rgb.size
+     small = img_rgb.resize((max(1, w // factor), max(1, h // factor)), Image.BILINEAR)
+     # Upsample with NEAREST to exaggerate the blocky artefacts
+     low_q = small.resize((w, h), Image.NEAREST)
+     return low_q
+
+
+ # -----------------------------------------------------------------------------
+ # Master dispatch
+ # -----------------------------------------------------------------------------
+
+
+ def _make_preview(
+     control_image: Image.Image,
+     mode: str,
+     canny_threshold_1: int = 100,
+     canny_threshold_2: int = 200,
+ ) -> Image.Image:
+     """Return a quick preview image for the requested *mode*.
+
+     Parameters
+     ----------
+     control_image : PIL.Image
+         The input image selected by the user.
+     mode : str
+         One of the keys of :data:`MODE_MAPPING`.
+     canny_threshold_1, canny_threshold_2 : int, optional
+         Only used when *mode* is "canny" (passed straight to OpenCV's Canny).
+     """
+     mode = mode.lower()
+     if mode not in MODE_MAPPING:
+         warnings.warn(f"Unknown preview mode '{mode}'. Returning the image untouched.")
+         return control_image
+
+     if mode == "canny":
+         return _preview_canny(control_image, canny_threshold_1, canny_threshold_2)
+     if mode == "tile":
+         return _preview_tile(control_image)
+     if mode == "depth":
+         return _preview_depth(control_image)
+     if mode == "blur":
+         return _preview_blur(control_image)
+     if mode == "pose":
+         return _preview_pose(control_image)
+     if mode == "gray":
+         return _preview_gray(control_image)
+     if mode == "low quality":
+         return _preview_low_quality(control_image)
+
+     # Fallback – should never happen thanks to the early mode check
+     return control_image
+
+
+ # --------------------------------------------------
+ # Inference function
+ # --------------------------------------------------
+
+
+ @spaces.GPU
+ def infer(
+     control_image: Image.Image,
+     prompt: str,
+     mode: str,
+     control_strength: float,
+     seed: int,
+     randomize_seed: bool,
+     guidance_scale: float,
+     num_inference_steps: int,
+     canny_threshold_1: int,
+     canny_threshold_2: int,
+ ):
+     if control_image is None:
+         raise gr.Error("Please upload a control image first.")
+
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+
+     gen = torch.Generator(device).manual_seed(seed)
+     w, h = control_image.size
+
+     preprocessed = _make_preview(
+         control_image, mode, canny_threshold_1, canny_threshold_2
+     )
+
+     result = pipe(
+         prompt=prompt,
+         control_image=[preprocessed],
+         control_mode=[MODE_MAPPING[mode]],
+         width=w,
+         height=h,
+         controlnet_conditioning_scale=[control_strength],
+         num_inference_steps=num_inference_steps,
+         guidance_scale=guidance_scale,
+         generator=gen,
+     ).images[0]
+
+     return result, seed, preprocessed
+
+
+ # --------------------------------------------------
+ # Gradio UI
+ # --------------------------------------------------
+ css = """#wrapper {max-width: 960px; margin: 0 auto;}"""
+ with gr.Blocks(css=css, elem_id="wrapper") as demo:
+     gr.Markdown("## FLUX.1-dev-ControlNet-Union-Pro by Frank")
+     gr.Markdown(
+         "A unified ControlNet for **FLUX.1-dev** from the InstantX team and Shakker Labs. "
+         "Recommended strengths: *canny 0.76*. Long prompts usually help."
+     )
+
+     # ------------ Image panel row ------------
+     with gr.Row():
+         control_image = gr.Image(
+             label="Upload an image",
+             type="pil",
+             height=512 + 256,
+         )
+         result_image = gr.Image(label="Result", height=512 + 256)
+         preview_image = gr.Image(label="Pre-processed Cond", height=512 + 256)
+
+     # ------------ Prompt ------------
+     prompt_txt = gr.Textbox(label="Prompt", value="White background", lines=1)
+
+     # ------------ ControlNet settings ------------
+     with gr.Row():
+         with gr.Column():
+             gr.Markdown("### ControlNet")
+             mode_radio = gr.Radio(
+                 choices=list(MODE_MAPPING.keys()), value="canny", label="Mode"
+             )
+             strength_slider = gr.Slider(
+                 0.0, 1.0, value=0.76, step=0.01, label="Control strength"
+             )
+             gr.Markdown("### Preprocess")
+             canny_threshold_1 = gr.Slider(
+                 0, 500, step=1, value=100, label="Canny threshold 1"
+             )
+             canny_threshold_2 = gr.Slider(
+                 0, 500, step=1, value=200, label="Canny threshold 2"
+             )
+
+         with gr.Column():
+             seed_slider = gr.Slider(0, MAX_SEED, step=1, value=42, label="Seed")
+             randomize_chk = gr.Checkbox(label="Randomize seed", value=False)
+             guidance_slider = gr.Slider(
+                 0.0, 10.0, step=0.1, value=3.5, label="Guidance scale"
+             )
+             steps_slider = gr.Slider(1, 50, step=1, value=50, label="Inference steps")
+
+     submit_btn = gr.Button("Submit")
+
+     submit_btn.click(
+         fn=infer,
+         inputs=[
+             control_image,
+             prompt_txt,
+             mode_radio,
+             strength_slider,
+             seed_slider,
+             randomize_chk,
+             guidance_slider,
+             steps_slider,
+             canny_threshold_1,
+             canny_threshold_2,
+         ],
+         outputs=[result_image, seed_slider, preview_image],
+     )
+
+ if __name__ == "__main__":
+     demo.launch()
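
The preview helpers in this diff are pure NumPy/OpenCV/PIL, so they can be smoke-tested without downloading any FLUX weights. Below is a minimal sketch (not part of the commit; the synthetic gradient image is made up for illustration, and it assumes the helpers and `MODE_MAPPING` from app.py are in scope) that pushes one image through every `_make_preview` mode:

```python
import numpy as np
from PIL import Image

# Hypothetical test image: a 256x256 horizontal gradient with a bright
# square, so the edge-based modes have something to detect.
arr = np.tile(np.linspace(0, 255, 256, dtype=np.uint8), (256, 1))
arr = np.stack([arr] * 3, axis=-1)
arr[96:160, 96:160] = 255
img = Image.fromarray(arr)

for mode in MODE_MAPPING:
    out = _make_preview(img, mode)
    print(f"{mode:<12} -> {out.size}")
```

Note that the "tile" preview doubles each dimension (a 256×256 input comes back as 512×512) while every other mode preserves the input size; `infer` still passes the original width/height to the pipeline, so the tiled conditioning image is presumably scaled back down during the pipeline's own preprocessing.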
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ accelerate
+ diffusers
+ invisible_watermark
+ torch
+ transformers
+ xformers
+ sentencepiece==0.2.0
+ opencv-python
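
A note on this list: app.py also imports `gradio`, `spaces`, and `huggingface_hub`, which a Hugging Face Gradio/ZeroGPU Space typically provides out of the box but a local environment may not, and the pose preview silently falls back to Canny unless `mediapipe` is installed. For a local run, something like `pip install -r requirements.txt gradio spaces mediapipe` (untested here) would cover the gap.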