Inmental commited on
Commit
a918b23
·
verified ·
1 Parent(s): 05ce455

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. gradio_sketch2image.py +382 -123
gradio_sketch2image.py CHANGED
@@ -1,123 +1,382 @@
1
- import random
2
- import numpy as np
3
- from PIL import Image
4
- from io import BytesIO
5
-
6
- import torch
7
- import torchvision.transforms.functional as F
8
- import gradio as gr
9
-
10
- from src.pix2pix_turbo import Pix2Pix_Turbo
11
-
12
- # Initialize the model
13
- model = Pix2Pix_Turbo("sketch_to_image_stochastic")
14
-
15
- # Define styles and related settings
16
- style_list = [
17
- {
18
- "name": "Cinematic",
19
- "prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
20
- },
21
- {
22
- "name": "3D Model",
23
- "prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
24
- },
25
- {
26
- "name": "Anime",
27
- "prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
28
- },
29
- {
30
- "name": "Digital Art",
31
- "prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed",
32
- },
33
- {
34
- "name": "Photographic",
35
- "prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed",
36
- },
37
- {
38
- "name": "Pixel art",
39
- "prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics",
40
- },
41
- {
42
- "name": "Fantasy art",
43
- "prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
44
- },
45
- {
46
- "name": "Neonpunk",
47
- "prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
48
- },
49
- {
50
- "name": "Manga",
51
- "prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style",
52
- },
53
- ]
54
-
55
- styles = {style["name"]: style["prompt"] for style in style_list}
56
- STYLE_NAMES = list(styles.keys())
57
- DEFAULT_STYLE_NAME = "Fantasy art"
58
- MAX_SEED = np.iinfo(np.int32).max
59
-
60
- # API Function
61
- def process_image(image, prompt, style_name, seed, val_r):
62
- # Apply the selected style prompt template
63
- prompt_template = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
64
- prompt = prompt_template.replace("{prompt}", prompt)
65
-
66
- # Convert the image to RGB and tensor format
67
- image = image.convert("RGB")
68
- image_t = F.to_tensor(image) > 0.5
69
-
70
- # Set the random seed for reproducibility
71
- torch.manual_seed(seed)
72
-
73
- # Prepare the tensor input for the model
74
- with torch.no_grad():
75
- c_t = image_t.unsqueeze(0).cuda().float()
76
- B, C, H, W = c_t.shape
77
- noise = torch.randn((1, 4, H // 8, W // 8), device=c_t.device)
78
- output_image = model(c_t, prompt, deterministic=False, r=val_r, noise_map=noise)
79
-
80
- # Convert the output tensor to a PIL image
81
- output_pil = F.to_pil_image(output_image[0].cpu() * 0.5 + 0.5)
82
-
83
- # Return the processed image
84
- return output_pil
85
-
86
- # Gradio Interface
87
- with gr.Blocks(css="style.css") as demo:
88
- gr.HTML("""
89
- <div style="text-align: center;">
90
- <h2>Image Processing API</h2>
91
- </div>
92
- """)
93
-
94
- with gr.Row():
95
- with gr.Column():
96
- image_input = gr.Image(label="Input Image", type="pil", tool="editor", interactive=True)
97
- prompt_input = gr.Textbox(label="Prompt", value="")
98
- style_dropdown = gr.Dropdown(label="Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME)
99
- seed_input = gr.Number(label="Seed", value=42)
100
- val_r_slider = gr.Slider(label="Sketch guidance (r)", minimum=0, maximum=1, step=0.01, value=0.4)
101
-
102
- process_button = gr.Button("Process Image")
103
-
104
- with gr.Column():
105
- image_output = gr.Image(label="Processed Image", interactive=False)
106
-
107
- # Linking the button to the API function
108
- process_button.click(
109
- fn=process_image,
110
- inputs=[image_input, prompt_input, style_dropdown, seed_input, val_r_slider],
111
- outputs=[image_output],
112
- )
113
-
114
- # Expose the API function for external access
115
- demo.load(
116
- fn=process_image,
117
- inputs=[gr.Image(type="pil"), gr.Textbox(), gr.Dropdown(choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME), gr.Number(value=42), gr.Slider(minimum=0, maximum=1, step=0.01, value=0.4)],
118
- outputs=gr.Image(),
119
- api_name="process_image",
120
- )
121
-
122
- if __name__ == "__main__":
123
- demo.queue().launch(debug=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import numpy as np
3
+ from PIL import Image
4
+ import base64
5
+ from io import BytesIO
6
+
7
+ import torch
8
+ import torchvision.transforms.functional as F
9
+ import gradio as gr
10
+
11
+ from src.pix2pix_turbo import Pix2Pix_Turbo
12
+
13
+ model = Pix2Pix_Turbo("sketch_to_image_stochastic")
14
+
15
+ style_list = [
16
+ {
17
+ "name": "Cinematic",
18
+ "prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
19
+ },
20
+ {
21
+ "name": "3D Model",
22
+ "prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
23
+ },
24
+ {
25
+ "name": "Anime",
26
+ "prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
27
+ },
28
+ {
29
+ "name": "Digital Art",
30
+ "prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed",
31
+ },
32
+ {
33
+ "name": "Photographic",
34
+ "prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed",
35
+ },
36
+ {
37
+ "name": "Pixel art",
38
+ "prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics",
39
+ },
40
+ {
41
+ "name": "Fantasy art",
42
+ "prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
43
+ },
44
+ {
45
+ "name": "Neonpunk",
46
+ "prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
47
+ },
48
+ {
49
+ "name": "Manga",
50
+ "prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style",
51
+ },
52
+ ]
53
+
54
+ styles = {k["name"]: k["prompt"] for k in style_list}
55
+ STYLE_NAMES = list(styles.keys())
56
+ DEFAULT_STYLE_NAME = "Fantasy art"
57
+ MAX_SEED = np.iinfo(np.int32).max
58
+
59
+
60
+ def pil_image_to_data_uri(img, format="PNG"):
61
+ buffered = BytesIO()
62
+ img.save(buffered, format=format)
63
+ img_str = base64.b64encode(buffered.getvalue()).decode()
64
+ return f"data:image/{format.lower()};base64,{img_str}"
65
+
66
+
67
+ def run(image, prompt, prompt_template, style_name, seed, val_r):
68
+ print(f"prompt: {prompt}")
69
+ print("sketch updated")
70
+ if image is None:
71
+ ones = Image.new("L", (512, 512), 255)
72
+ temp_uri = pil_image_to_data_uri(ones)
73
+ return ones, gr.update(link=temp_uri), gr.update(link=temp_uri)
74
+ prompt = prompt_template.replace("{prompt}", prompt)
75
+ image = image.convert("RGB")
76
+ image_t = F.to_tensor(image) > 0.5
77
+ print(f"r_val={val_r}, seed={seed}")
78
+ with torch.no_grad():
79
+ c_t = image_t.unsqueeze(0).cuda().float()
80
+ torch.manual_seed(seed)
81
+ B, C, H, W = c_t.shape
82
+ noise = torch.randn((1, 4, H // 8, W // 8), device=c_t.device)
83
+ output_image = model(c_t, prompt, deterministic=False, r=val_r, noise_map=noise)
84
+ output_pil = F.to_pil_image(output_image[0].cpu() * 0.5 + 0.5)
85
+ input_sketch_uri = pil_image_to_data_uri(Image.fromarray(255 - np.array(image)))
86
+ output_image_uri = pil_image_to_data_uri(output_pil)
87
+ return (
88
+ output_pil,
89
+ gr.update(link=input_sketch_uri),
90
+ gr.update(link=output_image_uri),
91
+ )
92
+
93
+
94
+ def update_canvas(use_line, use_eraser):
95
+ if use_eraser:
96
+ _color = "#ffffff"
97
+ brush_size = 20
98
+ if use_line:
99
+ _color = "#000000"
100
+ brush_size = 4
101
+ return gr.update(brush_radius=brush_size, brush_color=_color, interactive=True)
102
+
103
+
104
+ def upload_sketch(file):
105
+ _img = Image.open(file.name)
106
+ _img = _img.convert("L")
107
+ return gr.update(value=_img, source="upload", interactive=True)
108
+
109
+
110
+ scripts = """
111
+ async () => {
112
+ globalThis.theSketchDownloadFunction = () => {
113
+ console.log("test")
114
+ var link = document.createElement("a");
115
+ dataUri = document.getElementById('download_sketch').href
116
+ link.setAttribute("href", dataUri)
117
+ link.setAttribute("download", "sketch.png")
118
+ document.body.appendChild(link); // Required for Firefox
119
+ link.click();
120
+ document.body.removeChild(link); // Clean up
121
+
122
+ // also call the output download function
123
+ theOutputDownloadFunction();
124
+ return false
125
+ }
126
+
127
+ globalThis.theOutputDownloadFunction = () => {
128
+ console.log("test output download function")
129
+ var link = document.createElement("a");
130
+ dataUri = document.getElementById('download_output').href
131
+ link.setAttribute("href", dataUri);
132
+ link.setAttribute("download", "output.png");
133
+ document.body.appendChild(link); // Required for Firefox
134
+ link.click();
135
+ document.body.removeChild(link); // Clean up
136
+ return false
137
+ }
138
+
139
+ globalThis.UNDO_SKETCH_FUNCTION = () => {
140
+ console.log("undo sketch function")
141
+ var button_undo = document.querySelector('#input_image > div.image-container.svelte-p3y7hu > div.svelte-s6ybro > button:nth-child(1)');
142
+ // Create a new 'click' event
143
+ var event = new MouseEvent('click', {
144
+ 'view': window,
145
+ 'bubbles': true,
146
+ 'cancelable': true
147
+ });
148
+ button_undo.dispatchEvent(event);
149
+ }
150
+
151
+ globalThis.DELETE_SKETCH_FUNCTION = () => {
152
+ console.log("delete sketch function")
153
+ var button_del = document.querySelector('#input_image > div.image-container.svelte-p3y7hu > div.svelte-s6ybro > button:nth-child(2)');
154
+ // Create a new 'click' event
155
+ var event = new MouseEvent('click', {
156
+ 'view': window,
157
+ 'bubbles': true,
158
+ 'cancelable': true
159
+ });
160
+ button_del.dispatchEvent(event);
161
+ }
162
+
163
+ globalThis.togglePencil = () => {
164
+ el_pencil = document.getElementById('my-toggle-pencil');
165
+ el_pencil.classList.toggle('clicked');
166
+ // simulate a click on the gradio button
167
+ btn_gradio = document.querySelector("#cb-line > label > input");
168
+ var event = new MouseEvent('click', {
169
+ 'view': window,
170
+ 'bubbles': true,
171
+ 'cancelable': true
172
+ });
173
+ btn_gradio.dispatchEvent(event);
174
+ if (el_pencil.classList.contains('clicked')) {
175
+ document.getElementById('my-toggle-eraser').classList.remove('clicked');
176
+ document.getElementById('my-div-pencil').style.backgroundColor = "gray";
177
+ document.getElementById('my-div-eraser').style.backgroundColor = "white";
178
+ }
179
+ else {
180
+ document.getElementById('my-toggle-eraser').classList.add('clicked');
181
+ document.getElementById('my-div-pencil').style.backgroundColor = "white";
182
+ document.getElementById('my-div-eraser').style.backgroundColor = "gray";
183
+ }
184
+ }
185
+
186
+ globalThis.toggleEraser = () => {
187
+ element = document.getElementById('my-toggle-eraser');
188
+ element.classList.toggle('clicked');
189
+ // simulate a click on the gradio button
190
+ btn_gradio = document.querySelector("#cb-eraser > label > input");
191
+ var event = new MouseEvent('click', {
192
+ 'view': window,
193
+ 'bubbles': true,
194
+ 'cancelable': true
195
+ });
196
+ btn_gradio.dispatchEvent(event);
197
+ if (element.classList.contains('clicked')) {
198
+ document.getElementById('my-toggle-pencil').classList.remove('clicked');
199
+ document.getElementById('my-div-pencil').style.backgroundColor = "white";
200
+ document.getElementById('my-div-eraser').style.backgroundColor = "gray";
201
+ }
202
+ else {
203
+ document.getElementById('my-toggle-pencil').classList.add('clicked');
204
+ document.getElementById('my-div-pencil').style.backgroundColor = "gray";
205
+ document.getElementById('my-div-eraser').style.backgroundColor = "white";
206
+ }
207
+ }
208
+ }
209
+ """
210
+
211
+ with gr.Blocks(css="style.css") as demo:
212
+
213
+ gr.HTML(
214
+ """
215
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
216
+ <div>
217
+ <h2><a href="https://github.com/GaParmar/img2img-turbo">One-Step Image Translation with Text-to-Image Models</a></h2>
218
+ <div>
219
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
220
+ <a href='https://gauravparmar.com/'>Gaurav Parmar, </a>
221
+ &nbsp;
222
+ <a href='https://taesung.me/'> Taesung Park,</a>
223
+ &nbsp;
224
+ <a href='https://www.cs.cmu.edu/~srinivas/'>Srinivasa Narasimhan, </a>
225
+ &nbsp;
226
+ <a href='https://www.cs.cmu.edu/~junyanz/'> Jun-Yan Zhu </a>
227
+ </div>
228
+ </div>
229
+ </br>
230
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
231
+ <a href='https://arxiv.org/abs/2403.12036'>
232
+ <img src="https://img.shields.io/badge/arXiv-2403.12036-red">
233
+ </a>
234
+ &nbsp;
235
+ <a href='https://github.com/GaParmar/img2img-turbo'>
236
+ <img src='https://img.shields.io/badge/github-%23121011.svg'>
237
+ </a>
238
+ &nbsp;
239
+ <a href='https://github.com/GaParmar/img2img-turbo/blob/main/LICENSE'>
240
+ <img src='https://img.shields.io/badge/license-MIT-lightgrey'>
241
+ </a>
242
+ </div>
243
+ </div>
244
+ </div>
245
+ <div>
246
+ </br>
247
+ </div>
248
+ """
249
+ )
250
+
251
+ # these are hidden buttons that are used to trigger the canvas changes
252
+ line = gr.Checkbox(label="line", value=False, elem_id="cb-line")
253
+ eraser = gr.Checkbox(label="eraser", value=False, elem_id="cb-eraser")
254
+ with gr.Row(elem_id="main_row"):
255
+ with gr.Column(elem_id="column_input"):
256
+ gr.Markdown("## INPUT", elem_id="input_header")
257
+ image = gr.Image(
258
+ source="canvas",
259
+ tool="color-sketch",
260
+ type="pil",
261
+ image_mode="L",
262
+ invert_colors=True,
263
+ shape=(512, 512),
264
+ brush_radius=4,
265
+ height=440,
266
+ width=440,
267
+ brush_color="#000000",
268
+ interactive=True,
269
+ show_download_button=True,
270
+ elem_id="input_image",
271
+ show_label=False,
272
+ )
273
+ download_sketch = gr.Button(
274
+ "Download sketch", scale=1, elem_id="download_sketch"
275
+ )
276
+
277
+ gr.HTML(
278
+ """
279
+ <div class="button-row">
280
+ <div id="my-div-pencil" class="pad2"> <button id="my-toggle-pencil" onclick="return togglePencil(this)"></button> </div>
281
+ <div id="my-div-eraser" class="pad2"> <button id="my-toggle-eraser" onclick="return toggleEraser(this)"></button> </div>
282
+ <div class="pad2"> <button id="my-button-undo" onclick="return UNDO_SKETCH_FUNCTION(this)"></button> </div>
283
+ <div class="pad2"> <button id="my-button-clear" onclick="return DELETE_SKETCH_FUNCTION(this)"></button> </div>
284
+ <div class="pad2"> <button href="TODO" download="image" id="my-button-down" onclick='return theSketchDownloadFunction()'></button> </div>
285
+ </div>
286
+ """
287
+ )
288
+ # gr.Markdown("## Prompt", elem_id="tools_header")
289
+ prompt = gr.Textbox(label="Prompt", value="", show_label=True)
290
+ with gr.Row():
291
+ style = gr.Dropdown(
292
+ label="Style",
293
+ choices=STYLE_NAMES,
294
+ value=DEFAULT_STYLE_NAME,
295
+ scale=1,
296
+ )
297
+ prompt_temp = gr.Textbox(
298
+ label="Prompt Style Template",
299
+ value=styles[DEFAULT_STYLE_NAME],
300
+ scale=2,
301
+ max_lines=1,
302
+ )
303
+
304
+ with gr.Row():
305
+ val_r = gr.Slider(
306
+ label="Sketch guidance: ",
307
+ show_label=True,
308
+ minimum=0,
309
+ maximum=1,
310
+ value=0.4,
311
+ step=0.01,
312
+ scale=3,
313
+ )
314
+ seed = gr.Textbox(label="Seed", value=42, scale=1, min_width=50)
315
+ randomize_seed = gr.Button("Random", scale=1, min_width=50)
316
+
317
+ with gr.Column(elem_id="column_process", min_width=50, scale=0.4):
318
+ gr.Markdown("## pix2pix-turbo", elem_id="description")
319
+ run_button = gr.Button("Run", min_width=50)
320
+
321
+ with gr.Column(elem_id="column_output"):
322
+ gr.Markdown("## OUTPUT", elem_id="output_header")
323
+ result = gr.Image(
324
+ label="Result",
325
+ height=440,
326
+ width=440,
327
+ elem_id="output_image",
328
+ show_label=False,
329
+ show_download_button=True,
330
+ )
331
+ download_output = gr.Button("Download output", elem_id="download_output")
332
+ gr.Markdown("### Instructions")
333
+ gr.Markdown("**1**. Enter a text prompt (e.g. cat)")
334
+ gr.Markdown("**2**. Start sketching")
335
+ gr.Markdown("**3**. Change the image style using a style template")
336
+ gr.Markdown("**4**. Adjust the effect of sketch guidance using the slider")
337
+ gr.Markdown("**5**. Try different seeds to generate different results")
338
+
339
+ eraser.change(
340
+ fn=lambda x: gr.update(value=not x),
341
+ inputs=[eraser],
342
+ outputs=[line],
343
+ queue=False,
344
+ api_name=False,
345
+ ).then(update_canvas, [line, eraser], [image])
346
+ line.change(
347
+ fn=lambda x: gr.update(value=not x),
348
+ inputs=[line],
349
+ outputs=[eraser],
350
+ queue=False,
351
+ api_name=False,
352
+ ).then(update_canvas, [line, eraser], [image])
353
+
354
+ demo.load(None, None, None, _js=scripts)
355
+ randomize_seed.click(
356
+ lambda x: random.randint(0, MAX_SEED),
357
+ inputs=[],
358
+ outputs=seed,
359
+ queue=False,
360
+ api_name=False,
361
+ )
362
+ inputs = [image, prompt, prompt_temp, style, seed, val_r]
363
+ outputs = [result, download_sketch, download_output]
364
+ prompt.submit(fn=run, inputs=inputs, outputs=outputs, api_name=False)
365
+ style.change(
366
+ lambda x: styles[x],
367
+ inputs=[style],
368
+ outputs=[prompt_temp],
369
+ queue=False,
370
+ api_name=False,
371
+ ).then(
372
+ fn=run,
373
+ inputs=inputs,
374
+ outputs=outputs,
375
+ api_name=False,
376
+ )
377
+ val_r.change(run, inputs=inputs, outputs=outputs, queue=False, api_name=False)
378
+ run_button.click(fn=run, inputs=inputs, outputs=outputs, api_name=False)
379
+ image.change(run, inputs=inputs, outputs=outputs, queue=False, api_name=False)
380
+
381
+ if __name__ == "__main__":
382
+ demo.queue().launch(debug=True, share=True)