dawood HF Staff commited on
Commit
44c5c55
·
verified ·
1 Parent(s): e80fdf0

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +434 -0
app.py ADDED
@@ -0,0 +1,434 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import os
3
+ import gc
4
+ import gradio as gr
5
+ import numpy as np
6
+ import torch
7
+ import json
8
+ import config
9
+ import utils
10
+ import logging
11
+ from PIL import Image, PngImagePlugin
12
+ from datetime import datetime
13
+ from diffusers.models import AutoencoderKL
14
+ from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
15
+
16
+ logging.basicConfig(level=logging.INFO)
17
+ logger = logging.getLogger(__name__)
18
+
19
+ DESCRIPTION = "Animagine XL 3.1"
20
+ if not torch.cuda.is_available():
21
+ DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU. </p>"
22
+ IS_COLAB = utils.is_google_colab() or os.getenv("IS_COLAB") == "1"
23
+ HF_TOKEN = os.getenv("HF_TOKEN")
24
+ CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") == "0"
25
+ MIN_IMAGE_SIZE = int(os.getenv("MIN_IMAGE_SIZE", "512"))
26
+ MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "2048"))
27
+ USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE") == "1"
28
+ ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
29
+ OUTPUT_DIR = os.getenv("OUTPUT_DIR", "./outputs")
30
+
31
+ MODEL = os.getenv(
32
+ "MODEL",
33
+ "cagliostrolab/animagine-xl-3.1",
34
+ )
35
+
36
+ torch.backends.cudnn.deterministic = True
37
+ torch.backends.cudnn.benchmark = False
38
+ torch.backends.cuda.matmul.allow_tf32 = True
39
+
40
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
41
+
42
+
43
+ def load_pipeline(model_name):
44
+ vae = AutoencoderKL.from_pretrained(
45
+ "madebyollin/sdxl-vae-fp16-fix",
46
+ torch_dtype=torch.float16,
47
+ )
48
+ pipeline = (
49
+ StableDiffusionXLPipeline.from_single_file
50
+ if MODEL.endswith(".safetensors")
51
+ else StableDiffusionXLPipeline.from_pretrained
52
+ )
53
+
54
+ pipe = pipeline(
55
+ model_name,
56
+ vae=vae,
57
+ torch_dtype=torch.float16,
58
+ custom_pipeline="lpw_stable_diffusion_xl",
59
+ use_safetensors=True,
60
+ add_watermarker=False,
61
+ use_auth_token=HF_TOKEN,
62
+ )
63
+
64
+ pipe.to(device)
65
+ return pipe
66
+
67
+
68
+ @spaces.GPU
69
+ def generate(
70
+ prompt: str,
71
+ negative_prompt: str = "",
72
+ seed: int = 0,
73
+ custom_width: int = 1024,
74
+ custom_height: int = 1024,
75
+ guidance_scale: float = 7.0,
76
+ num_inference_steps: int = 28,
77
+ sampler: str = "Euler a",
78
+ aspect_ratio_selector: str = "896 x 1152",
79
+ style_selector: str = "(None)",
80
+ quality_selector: str = "Standard v3.1",
81
+ use_upscaler: bool = False,
82
+ upscaler_strength: float = 0.55,
83
+ upscale_by: float = 1.5,
84
+ add_quality_tags: bool = True,
85
+ progress=gr.Progress(track_tqdm=True),
86
+ ):
87
+ generator = utils.seed_everything(seed)
88
+
89
+ width, height = utils.aspect_ratio_handler(
90
+ aspect_ratio_selector,
91
+ custom_width,
92
+ custom_height,
93
+ )
94
+
95
+ prompt = utils.add_wildcard(prompt, wildcard_files)
96
+
97
+ prompt, negative_prompt = utils.preprocess_prompt(
98
+ quality_prompt, quality_selector, prompt, negative_prompt, add_quality_tags
99
+ )
100
+ prompt, negative_prompt = utils.preprocess_prompt(
101
+ styles, style_selector, prompt, negative_prompt
102
+ )
103
+
104
+ width, height = utils.preprocess_image_dimensions(width, height)
105
+
106
+ backup_scheduler = pipe.scheduler
107
+ pipe.scheduler = utils.get_scheduler(pipe.scheduler.config, sampler)
108
+
109
+ if use_upscaler:
110
+ upscaler_pipe = StableDiffusionXLImg2ImgPipeline(**pipe.components)
111
+ metadata = {
112
+ "prompt": prompt,
113
+ "negative_prompt": negative_prompt,
114
+ "resolution": f"{width} x {height}",
115
+ "guidance_scale": guidance_scale,
116
+ "num_inference_steps": num_inference_steps,
117
+ "seed": seed,
118
+ "sampler": sampler,
119
+ "sdxl_style": style_selector,
120
+ "add_quality_tags": add_quality_tags,
121
+ "quality_tags": quality_selector,
122
+ }
123
+
124
+ if use_upscaler:
125
+ new_width = int(width * upscale_by)
126
+ new_height = int(height * upscale_by)
127
+ metadata["use_upscaler"] = {
128
+ "upscale_method": "nearest-exact",
129
+ "upscaler_strength": upscaler_strength,
130
+ "upscale_by": upscale_by,
131
+ "new_resolution": f"{new_width} x {new_height}",
132
+ }
133
+ else:
134
+ metadata["use_upscaler"] = None
135
+ metadata["Model"] = {
136
+ "Model": DESCRIPTION,
137
+ "Model hash": "e3c47aedb0",
138
+ }
139
+
140
+ logger.info(json.dumps(metadata, indent=4))
141
+
142
+ try:
143
+ if use_upscaler:
144
+ latents = pipe(
145
+ prompt=prompt,
146
+ negative_prompt=negative_prompt,
147
+ width=width,
148
+ height=height,
149
+ guidance_scale=guidance_scale,
150
+ num_inference_steps=num_inference_steps,
151
+ generator=generator,
152
+ output_type="latent",
153
+ ).images
154
+ upscaled_latents = utils.upscale(latents, "nearest-exact", upscale_by)
155
+ images = upscaler_pipe(
156
+ prompt=prompt,
157
+ negative_prompt=negative_prompt,
158
+ image=upscaled_latents,
159
+ guidance_scale=guidance_scale,
160
+ num_inference_steps=num_inference_steps,
161
+ strength=upscaler_strength,
162
+ generator=generator,
163
+ output_type="pil",
164
+ ).images
165
+ else:
166
+ images = pipe(
167
+ prompt=prompt,
168
+ negative_prompt=negative_prompt,
169
+ width=width,
170
+ height=height,
171
+ guidance_scale=guidance_scale,
172
+ num_inference_steps=num_inference_steps,
173
+ generator=generator,
174
+ output_type="pil",
175
+ ).images
176
+
177
+ if images:
178
+ image_paths = [
179
+ utils.save_image(image, metadata, OUTPUT_DIR, IS_COLAB)
180
+ for image in images
181
+ ]
182
+
183
+ for image_path in image_paths:
184
+ logger.info(f"Image saved as {image_path} with metadata")
185
+
186
+ return image_paths, metadata
187
+ except Exception as e:
188
+ logger.exception(f"An error occurred: {e}")
189
+ raise
190
+ finally:
191
+ if use_upscaler:
192
+ del upscaler_pipe
193
+ pipe.scheduler = backup_scheduler
194
+ utils.free_memory()
195
+
196
+
197
+ if torch.cuda.is_available():
198
+ pipe = load_pipeline(MODEL)
199
+ logger.info("Loaded on Device!")
200
+ else:
201
+ pipe = None
202
+
203
+ styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in config.style_list}
204
+ quality_prompt = {
205
+ k["name"]: (k["prompt"], k["negative_prompt"]) for k in config.quality_prompt_list
206
+ }
207
+
208
+ wildcard_files = utils.load_wildcard_files("wildcard")
209
+
210
+ with gr.Blocks(css="style.css", theme="NoCrypt/miku@1.2.1") as demo:
211
+ title = gr.HTML(
212
+ f"""<h1><span>{DESCRIPTION}</span></h1>""",
213
+ elem_id="title",
214
+ )
215
+ gr.Markdown(
216
+ f"""Gradio demo for [cagliostrolab/animagine-xl-3.1](https://huggingface.co/cagliostrolab/animagine-xl-3.1)""",
217
+ elem_id="subtitle",
218
+ )
219
+ gr.HTML(
220
+ f"""
221
+ <a href="https://discord.gg/xhAYSh2Psu" target="_blank" class="discord-btn" target="_blank">
222
+ <div class="discord-content">
223
+ <span class="discord-icon">
224
+ <svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" fill="currentColor" class="bi bi-discord" viewBox="0 0 16 16">
225
+ <path d="M13.545 2.907a13.2 13.2 0 0 0-3.257-1.011.05.05 0 0 0-.052.025c-.141.25-.297.577-.406.833a12.2 12.2 0 0 0-3.658 0 8 8 0 0 0-.412-.833.05.05 0 0 0-.052-.025c-1.125.194-2.22.534-3.257 1.011a.04.04 0 0 0-.021.018C.356 6.024-.213 9.047.066 12.032q.003.022.021.037a13.3 13.3 0 0 0 3.995 2.02.05.05 0 0 0 .056-.019q.463-.63.818-1.329a.05.05 0 0 0-.01-.059l-.018-.011a9 9 0 0 1-1.248-.595.05.05 0 0 1-.02-.066l.015-.019q.127-.095.248-.195a.05.05 0 0 1 .051-.007c2.619 1.196 5.454 1.196 8.041 0a.05.05 0 0 1 .053.007q.121.1.248.195a.05.05 0 0 1-.004.085 8 8 0 0 1-1.249.594.05.05 0 0 0-.03.03.05.05 0 0 0 .003.041c.24.465.515.909.817 1.329a.05.05 0 0 0 .056.019 13.2 13.2 0 0 0 4.001-2.02.05.05 0 0 0 .021-.037c.334-3.451-.559-6.449-2.366-9.106a.03.03 0 0 0-.02-.019m-8.198 7.307c-.789 0-1.438-.724-1.438-1.612s.637-1.613 1.438-1.613c.807 0 1.45.73 1.438 1.613 0 .888-.637 1.612-1.438 1.612m5.316 0c-.788 0-1.438-.724-1.438-1.612s.637-1.613 1.438-1.613c.807 0 1.451.73 1.438 1.613 0 .888-.631 1.612-1.438 1.612"/>
226
+ </svg>
227
+ </span>
228
+ <span class="discord-text">Discord</span>
229
+ </div>
230
+ </a>
231
+ """
232
+ )
233
+ gr.DuplicateButton(
234
+ value="Duplicate Space for private use",
235
+ elem_id="duplicate-button",
236
+ visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
237
+ )
238
+
239
+ # Create sidebar for advanced settings
240
+ with gr.Sidebar():
241
+ # Quality Tags
242
+ with gr.Group():
243
+ add_quality_tags = gr.Checkbox(
244
+ label="Add Quality Tags", value=True
245
+ )
246
+ quality_selector = gr.Dropdown(
247
+ label="Quality Tags Presets",
248
+ interactive=True,
249
+ choices=list(quality_prompt.keys()),
250
+ value="Standard v3.1",
251
+ )
252
+
253
+ # Style Preset
254
+ with gr.Group():
255
+ style_selector = gr.Radio(
256
+ label="Style Preset",
257
+ container=True,
258
+ interactive=True,
259
+ choices=list(styles.keys()),
260
+ value="(None)",
261
+ )
262
+
263
+ # Aspect Ratio
264
+ with gr.Group():
265
+ aspect_ratio_selector = gr.Radio(
266
+ label="Aspect Ratio",
267
+ choices=config.aspect_ratios,
268
+ value="896 x 1152",
269
+ container=True,
270
+ )
271
+
272
+ # Custom Resolution (initially hidden)
273
+ with gr.Group(visible=False) as custom_resolution:
274
+ with gr.Row():
275
+ custom_width = gr.Slider(
276
+ label="Width",
277
+ minimum=MIN_IMAGE_SIZE,
278
+ maximum=MAX_IMAGE_SIZE,
279
+ step=8,
280
+ value=1024,
281
+ )
282
+ custom_height = gr.Slider(
283
+ label="Height",
284
+ minimum=MIN_IMAGE_SIZE,
285
+ maximum=MAX_IMAGE_SIZE,
286
+ step=8,
287
+ value=1024,
288
+ )
289
+
290
+ # Upscaler options
291
+ with gr.Group():
292
+ use_upscaler = gr.Checkbox(label="Use Upscaler", value=False)
293
+ upscaler_strength = gr.Slider(
294
+ label="Strength",
295
+ minimum=0,
296
+ maximum=1,
297
+ step=0.05,
298
+ value=0.55,
299
+ visible=False,
300
+ )
301
+ upscale_by = gr.Slider(
302
+ label="Upscale by",
303
+ minimum=1,
304
+ maximum=1.5,
305
+ step=0.1,
306
+ value=1.5,
307
+ visible=False,
308
+ )
309
+
310
+ # Sampler selection
311
+ with gr.Group():
312
+ sampler = gr.Dropdown(
313
+ label="Sampler",
314
+ choices=config.sampler_list,
315
+ interactive=True,
316
+ value="Euler a",
317
+ )
318
+
319
+ # Seed options
320
+ with gr.Group():
321
+ seed = gr.Slider(
322
+ label="Seed", minimum=0, maximum=utils.MAX_SEED, step=1, value=0
323
+ )
324
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
325
+
326
+ # Generation parameters
327
+ with gr.Group():
328
+ guidance_scale = gr.Slider(
329
+ label="Guidance scale",
330
+ minimum=1,
331
+ maximum=12,
332
+ step=0.1,
333
+ value=7.0,
334
+ )
335
+ num_inference_steps = gr.Slider(
336
+ label="Number of inference steps",
337
+ minimum=1,
338
+ maximum=50,
339
+ step=1,
340
+ value=28,
341
+ )
342
+
343
+ # Main content area with simplified layout
344
+ with gr.Column():
345
+ # Prompt inputs at the top
346
+ prompt = gr.Text(
347
+ label="Prompt",
348
+ max_lines=5,
349
+ placeholder="Enter your prompt",
350
+ )
351
+ negative_prompt = gr.Text(
352
+ label="Negative Prompt",
353
+ max_lines=5,
354
+ placeholder="Enter a negative prompt",
355
+ )
356
+
357
+ # Result gallery
358
+ result = gr.Gallery(
359
+ label="Result",
360
+ columns=1,
361
+ height='600px',
362
+ preview=True,
363
+ show_label=False
364
+ )
365
+
366
+ # Generate button under the gallery
367
+ run_button = gr.Button("Generate", variant="primary", size="lg")
368
+
369
+ # Generation parameters in accordion
370
+ with gr.Accordion(label="Generation Parameters", open=False):
371
+ gr_metadata = gr.JSON(label="metadata", show_label=False)
372
+
373
+ # Examples
374
+ gr.Examples(
375
+ examples=config.examples,
376
+ inputs=prompt,
377
+ outputs=[result, gr_metadata],
378
+ fn=lambda *args, **kwargs: generate(*args, use_upscaler=True, **kwargs),
379
+ cache_examples=CACHE_EXAMPLES,
380
+ )
381
+
382
+ # Event handlers
383
+ use_upscaler.change(
384
+ fn=lambda x: [gr.update(visible=x), gr.update(visible=x)],
385
+ inputs=use_upscaler,
386
+ outputs=[upscaler_strength, upscale_by],
387
+ queue=False,
388
+ api_name=False,
389
+ )
390
+
391
+ aspect_ratio_selector.change(
392
+ fn=lambda x: gr.update(visible=x == "Custom"),
393
+ inputs=aspect_ratio_selector,
394
+ outputs=custom_resolution,
395
+ queue=False,
396
+ api_name=False,
397
+ )
398
+
399
+ gr.on(
400
+ triggers=[
401
+ prompt.submit,
402
+ negative_prompt.submit,
403
+ run_button.click,
404
+ ],
405
+ fn=utils.randomize_seed_fn,
406
+ inputs=[seed, randomize_seed],
407
+ outputs=seed,
408
+ queue=False,
409
+ api_name=False,
410
+ ).then(
411
+ fn=generate,
412
+ inputs=[
413
+ prompt,
414
+ negative_prompt,
415
+ seed,
416
+ custom_width,
417
+ custom_height,
418
+ guidance_scale,
419
+ num_inference_steps,
420
+ sampler,
421
+ aspect_ratio_selector,
422
+ style_selector,
423
+ quality_selector,
424
+ use_upscaler,
425
+ upscaler_strength,
426
+ upscale_by,
427
+ add_quality_tags,
428
+ ],
429
+ outputs=[result, gr_metadata],
430
+ api_name="run",
431
+ )
432
+
433
+ if __name__ == "__main__":
434
+ demo.queue(max_size=20).launch(debug=IS_COLAB, share=IS_COLAB)