omer11a committed on
Commit
49a7542
·
1 Parent(s): f0d244c

Decreased runtime

Browse files
Files changed (1) hide show
  1. app.py +141 -139
app.py CHANGED
@@ -35,11 +35,50 @@ COPY_LINK = """
35
  </a>
36
  Duplicate this space to generate more samples without waiting in queue
37
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  FOOTNOTE = """
39
  <p>The source code of this demo is based on the <a href="https://huggingface.co/spaces/gligen/demo/tree/main">GLIGEN demo</a>.</p>
40
  """
41
 
42
 
 
 
 
 
 
 
 
43
  def inference(
44
  boxes,
45
  prompts,
@@ -61,11 +100,7 @@ def inference(
61
  raise gr.Error("cuda is not available")
62
 
63
  device = torch.device("cuda")
64
- model_path = "stabilityai/stable-diffusion-xl-base-1.0"
65
- scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
66
- model = StableDiffusionXLPipeline.from_pretrained(model_path, scheduler=scheduler, torch_dtype=torch.float16).to(device)
67
- model.unet.set_default_attn_processor()
68
- model.enable_sequential_cpu_offload()
69
 
70
  seed_everything(seed)
71
  start_code = torch.randn([len(prompts), 4, 128, 128], device=device)
@@ -89,12 +124,14 @@ def inference(
89
  num_clusters_per_box=num_clusters_per_subject,
90
  )
91
 
92
- regiter_attention_editor_diffusers(model, editor)
93
 
94
- return model(prompts, latents=start_code, guidance_scale=classifier_free_guidance_scale).images
 
 
95
 
96
 
97
- @spaces.GPU(duration=500)
98
  def generate(
99
  prompt,
100
  subject_token_indices,
@@ -220,134 +257,99 @@ def main():
220
  }
221
  """
222
 
223
- nltk.download("averaged_perceptron_tagger")
224
-
225
- with gr.Blocks(
226
- css=css,
227
- title="Bounded Attention demo",
228
- ) as demo:
229
- gr.HTML(DESCRIPTION)
230
- gr.HTML(COPY_LINK)
231
-
232
- with gr.Column():
233
- gr.HTML("Scroll down to see examples of the required input format.")
234
-
235
- prompt = gr.Textbox(
236
- label="Text prompt",
237
- )
238
-
239
- subject_token_indices = gr.Textbox(
240
- label="The token indices of each subject (separate indices for the same subject with commas, and for different subjects with semicolons)",
241
- )
242
-
243
- filter_token_indices = gr.Textbox(
244
- label="Optional: The token indices to filter, i.e. conjunctions, numbers, postional relations, etc. (if left empty, this will be automatically inferred)",
245
- )
246
-
247
- num_tokens = gr.Textbox(
248
- label="Optional: The number of tokens in the prompt (We use this to verify your input, as sometimes rare words are split into more than one token)",
249
- )
250
-
251
- with gr.Row():
252
- sketchpad = gr.Sketchpad(label="Sketch Pad (draw each bounding box in a different layer)")
253
- layout_image = gr.Image(type="pil", label="Bounding Boxes", interactive=False)
254
-
255
- with gr.Row():
256
- clear_button = gr.Button(value="Clear")
257
- generate_layout_button = gr.Button(value="Generate layout")
258
- generate_image_button = gr.Button(value="Generate image")
259
-
260
- with gr.Row():
261
- out_images = gr.Gallery(type="pil", label="Generated Images", interactive=False)
262
-
263
- with gr.Accordion("Advanced Options", open=False):
264
- with gr.Column():
265
- description = """
266
- <div class="tooltip">Batch size &#9432
267
- <span class="tooltiptext">The number of images to generate.</span>
268
- </div>
269
- <div class="tooltip">Initial step size &#9432
270
- <span class="tooltiptext">The initial step size of the linear step size scheduler when performing guidance.</span>
271
- </div>
272
- <div class="tooltip">Final step size &#9432
273
- <span class="tooltiptext">The final step size of the linear step size scheduler when performing guidance.</span>
274
- </div>
275
- <div class="tooltip">Number of self-attention clusters per subject &#9432
276
- <span class="tooltiptext">Determines the number of clusters when clustering the self-attention maps (#clusters = #subject x #clusters_per_subject). Changing this value might improve semantics (adherence to the prompt), especially when the subjects exceed their bounding boxes.</span>
277
- </div>
278
- <div class="tooltip">Cross-attention loss scale factor &#9432
279
- <span class="tooltiptext">The scale factor of the cross-attention loss term. Increasing it will improve semantic control (adherence to the prompt), but may reduce image quality.</span>
280
- </div>
281
- <div class="tooltip">Self-attention loss scale factor &#9432
282
- <span class="tooltiptext">The scale factor of the self-attention loss term. Increasing it will improve layout control (adherence to the bounding boxes), but may reduce image quality.</span>
283
- </div>
284
- <div class="tooltip">Classifier-free guidance scale &#9432
285
- <span class="tooltiptext">The scale factor of classifier-free guidance.</span>
286
- </div>
287
- <div class="tooltip" >Number of Gradient Descent iterations per timestep &#9432
288
- <span class="tooltiptext">The number of Gradient Descent iterations for each timestep when performing guidance.</span>
289
- </div>
290
- <div class="tooltip" >Loss Threshold &#9432
291
- <span class="tooltiptext">If the loss is below the threshold, Gradient Descent stops for that timestep. </span>
292
- </div>
293
- <div class="tooltip" >Number of guidance steps &#9432
294
- <span class="tooltiptext">The number of timesteps in which to perform guidance.</span>
295
- </div>
296
- """
297
- gr.HTML(description)
298
- batch_size = gr.Slider(minimum=1, maximum=5, step=1, value=1, label="Number of samples (limited to one sample on current space)")
299
- init_step_size = gr.Slider(minimum=0, maximum=50, step=0.5, value=18, label="Initial step size")
300
- final_step_size = gr.Slider(minimum=0, maximum=20, step=0.5, value=5, label="Final step size")
301
- num_clusters_per_subject = gr.Slider(minimum=0, maximum=5, step=0.5, value=3, label="Number of clusters per subject")
302
- cross_loss_scale = gr.Slider(minimum=0, maximum=2, step=0.1, value=1, label="Cross-attention loss scale factor")
303
- self_loss_scale = gr.Slider(minimum=0, maximum=2, step=0.1, value=1, label="Self-attention loss scale factor")
304
- classifier_free_guidance_scale = gr.Slider(minimum=0, maximum=50, step=0.5, value=7.5, label="Classifier-free guidance Scale")
305
- num_iterations = gr.Slider(minimum=0, maximum=10, step=1, value=5, label="Number of Gradient Descent iterations")
306
- loss_threshold = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.2, label="Loss threshold")
307
- num_guidance_steps = gr.Slider(minimum=10, maximum=20, step=1, value=15, label="Number of timesteps to perform guidance")
308
- seed = gr.Slider(minimum=0, maximum=1000, step=1, value=445, label="Random Seed")
309
-
310
- boxes = gr.State([])
311
-
312
- clear_button.click(
313
- clear,
314
- inputs=[batch_size],
315
- outputs=[boxes, sketchpad, layout_image, out_images],
316
- queue=False,
317
- )
318
-
319
- generate_layout_button.click(
320
- draw,
321
- inputs=[sketchpad],
322
- outputs=[boxes, layout_image],
323
- queue=False,
324
- )
325
-
326
- generate_image_button.click(
327
- fn=generate,
328
- inputs=[
329
- prompt, subject_token_indices, filter_token_indices, num_tokens,
330
- init_step_size, final_step_size, num_clusters_per_subject, cross_loss_scale, self_loss_scale,
331
- classifier_free_guidance_scale, batch_size, num_iterations, loss_threshold, num_guidance_steps,
332
- seed,
333
- boxes,
334
- ],
335
- outputs=[out_images],
336
- queue=True,
337
- )
338
-
339
- with gr.Column():
340
- gr.Examples(
341
- examples=[
342
- ["a ginger kitten and a gray puppy in a yard", "2,3;6,7", "1,4,5,8,9", "10"],
343
- ["a realistic photo of a highway with a semi trailer and a concrete mixer and a helicopter", "9,10;13,14;17", "1,4,5,7,8,11,12,15,16", "17"],
344
- ],
345
- inputs=[prompt, subject_token_indices, filter_token_indices, num_tokens],
346
- )
347
-
348
- gr.HTML(FOOTNOTE)
349
-
350
- demo.launch(show_api=False, show_error=True)
351
-
352
- if __name__ == "__main__":
353
- main()
 
35
  </a>
36
  Duplicate this space to generate more samples without waiting in queue
37
  """
38
+ ADVANCED_OPTION_DESCRIPTION = """
39
+ <div class="tooltip" >Number of guidance steps &#9432
40
+ <span class="tooltiptext">The number of timesteps in which to perform guidance. Recommended value is 15, but increasing this will also increase the runtime.</span>
41
+ </div>
42
+ <div class="tooltip">Batch size &#9432
43
+ <span class="tooltiptext">The number of images to generate.</span>
44
+ </div>
45
+ <div class="tooltip">Initial step size &#9432
46
+ <span class="tooltiptext">The initial step size of the linear step size scheduler when performing guidance.</span>
47
+ </div>
48
+ <div class="tooltip">Final step size &#9432
49
+ <span class="tooltiptext">The final step size of the linear step size scheduler when performing guidance.</span>
50
+ </div>
51
+ <div class="tooltip">Number of self-attention clusters per subject &#9432
52
+ <span class="tooltiptext">Determines the number of clusters when clustering the self-attention maps (#clusters = #subject x #clusters_per_subject). Changing this value might improve semantics (adherence to the prompt), especially when the subjects exceed their bounding boxes.</span>
53
+ </div>
54
+ <div class="tooltip">Cross-attention loss scale factor &#9432
55
+ <span class="tooltiptext">The scale factor of the cross-attention loss term. Increasing it will improve semantic control (adherence to the prompt), but may reduce image quality.</span>
56
+ </div>
57
+ <div class="tooltip">Self-attention loss scale factor &#9432
58
+ <span class="tooltiptext">The scale factor of the self-attention loss term. Increasing it will improve layout control (adherence to the bounding boxes), but may reduce image quality.</span>
59
+ </div>
60
+ <div class="tooltip" >Number of Gradient Descent iterations per timestep &#9432
61
+ <span class="tooltiptext">The number of Gradient Descent iterations for each timestep when performing guidance.</span>
62
+ </div>
63
+ <div class="tooltip" >Loss Threshold &#9432
64
+ <span class="tooltiptext">If the loss is below the threshold, Gradient Descent stops for that timestep. </span>
65
+ </div>
66
+ <div class="tooltip">Classifier-free guidance scale &#9432
67
+ <span class="tooltiptext">The scale factor of classifier-free guidance.</span>
68
+ </div>
69
+ """
70
  FOOTNOTE = """
71
  <p>The source code of this demo is based on the <a href="https://huggingface.co/spaces/gligen/demo/tree/main">GLIGEN demo</a>.</p>
72
  """
73
 
74
 
75
+ MODEL_PATH = "stabilityai/stable-diffusion-xl-base-1.0"
76
+ scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
77
+ model = StableDiffusionXLPipeline.from_pretrained(MODEL_PATH, scheduler=scheduler, torch_dtype=torch.float16)
78
+ model.unet.set_default_attn_processor()
79
+ model.enable_sequential_cpu_offload()
80
+
81
+
82
  def inference(
83
  boxes,
84
  prompts,
 
100
  raise gr.Error("cuda is not available")
101
 
102
  device = torch.device("cuda")
103
+ model = model.to(device)
 
 
 
 
104
 
105
  seed_everything(seed)
106
  start_code = torch.randn([len(prompts), 4, 128, 128], device=device)
 
124
  num_clusters_per_box=num_clusters_per_subject,
125
  )
126
 
127
+ register_attention_editor_diffusers(model, editor)
128
 
129
+ images = model(prompts, latents=start_code, guidance_scale=classifier_free_guidance_scale).images
130
+ unregister_attention_editor_diffusers(model)
131
+ model.to(torch.device("cpu"))
132
 
133
 
134
+ @spaces.GPU(duration=300)
135
  def generate(
136
  prompt,
137
  subject_token_indices,
 
257
  }
258
  """
259
 
260
+ nltk.download("averaged_perceptron_tagger")
261
+
262
+ with gr.Blocks(
263
+ css=css,
264
+ title="Bounded Attention demo",
265
+ ) as demo:
266
+ gr.HTML(DESCRIPTION)
267
+ gr.HTML(COPY_LINK)
268
+
269
+ with gr.Column():
270
+ gr.HTML("Scroll down to see examples of the required input format.")
271
+
272
+ prompt = gr.Textbox(
273
+ label="Text prompt",
274
+ )
275
+
276
+ subject_token_indices = gr.Textbox(
277
+ label="The token indices of each subject (separate indices for the same subject with commas, and for different subjects with semicolons)",
278
+ )
279
+
280
+ filter_token_indices = gr.Textbox(
281
+ label="Optional: The token indices to filter, i.e. conjunctions, numbers, positional relations, etc. (if left empty, this will be automatically inferred)",
282
+ )
283
+
284
+ num_tokens = gr.Textbox(
285
+ label="Optional: The number of tokens in the prompt (We use this to verify your input, as sometimes rare words are split into more than one token)",
286
+ )
287
+
288
+ with gr.Row():
289
+ sketchpad = gr.Sketchpad(label="Sketch Pad (draw each bounding box in a different layer)")
290
+ layout_image = gr.Image(type="pil", label="Bounding Boxes", interactive=False)
291
+
292
+ with gr.Row():
293
+ clear_button = gr.Button(value="Clear")
294
+ generate_layout_button = gr.Button(value="Generate layout")
295
+ generate_image_button = gr.Button(value="Generate image")
296
+
297
+ with gr.Row():
298
+ out_images = gr.Gallery(type="pil", label="Generated Images", interactive=False)
299
+
300
+ with gr.Accordion("Advanced Options", open=False):
301
+ with gr.Column():
302
+ gr.HTML(ADVANCED_OPTION_DESCRIPTION)
303
+ batch_size = gr.Slider(minimum=1, maximum=5, step=1, value=1, label="Number of samples (limited to one sample on current space)")
304
+ num_guidance_steps = gr.Slider(minimum=5, maximum=20, step=1, value=8, label="Number of timesteps to perform guidance")
305
+ init_step_size = gr.Slider(minimum=0, maximum=50, step=0.5, value=25, label="Initial step size")
306
+ final_step_size = gr.Slider(minimum=0, maximum=20, step=0.5, value=10, label="Final step size")
307
+ num_clusters_per_subject = gr.Slider(minimum=0, maximum=5, step=0.5, value=3, label="Number of clusters per subject")
308
+ cross_loss_scale = gr.Slider(minimum=0, maximum=2, step=0.1, value=1, label="Cross-attention loss scale factor")
309
+ self_loss_scale = gr.Slider(minimum=0, maximum=2, step=0.1, value=1, label="Self-attention loss scale factor")
310
+ num_iterations = gr.Slider(minimum=0, maximum=10, step=1, value=5, label="Number of Gradient Descent iterations")
311
+ loss_threshold = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.2, label="Loss threshold")
312
+ classifier_free_guidance_scale = gr.Slider(minimum=0, maximum=50, step=0.5, value=7.5, label="Classifier-free guidance Scale")
313
+ seed = gr.Slider(minimum=0, maximum=1000, step=1, value=445, label="Random Seed")
314
+
315
+ boxes = gr.State([])
316
+
317
+ clear_button.click(
318
+ clear,
319
+ inputs=[batch_size],
320
+ outputs=[boxes, sketchpad, layout_image, out_images],
321
+ queue=False,
322
+ )
323
+
324
+ generate_layout_button.click(
325
+ draw,
326
+ inputs=[sketchpad],
327
+ outputs=[boxes, layout_image],
328
+ queue=False,
329
+ )
330
+
331
+ generate_image_button.click(
332
+ fn=generate,
333
+ inputs=[
334
+ prompt, subject_token_indices, filter_token_indices, num_tokens,
335
+ init_step_size, final_step_size, num_clusters_per_subject, cross_loss_scale, self_loss_scale,
336
+ classifier_free_guidance_scale, batch_size, num_iterations, loss_threshold, num_guidance_steps,
337
+ seed,
338
+ boxes,
339
+ ],
340
+ outputs=[out_images],
341
+ queue=True,
342
+ )
343
+
344
+ with gr.Column():
345
+ gr.Examples(
346
+ examples=[
347
+ ["a ginger kitten and a gray puppy in a yard", "2,3;6,7", "1,4,5,8,9", "10"],
348
+ ["a realistic photo of a highway with a semi trailer and a concrete mixer and a helicopter", "9,10;13,14;17", "1,4,5,7,8,11,12,15,16", "17"],
349
+ ],
350
+ inputs=[prompt, subject_token_indices, filter_token_indices, num_tokens],
351
+ )
352
+
353
+ gr.HTML(FOOTNOTE)
354
+
355
+ demo.launch(show_api=False, show_error=True)