omer11a committed on
Commit
0b407fa
·
1 Parent(s): 8fea73b

Added examples

Browse files
Files changed (1) hide show
  1. app.py +95 -20
app.py CHANGED
@@ -19,6 +19,31 @@ WHITE = 255
19
  COLORS = ["red", "blue", "green", "orange", "purple", "turquoise", "olive"]
20
 
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  def inference(
23
  boxes,
24
  prompts,
@@ -90,7 +115,7 @@ def generate(
90
  loss_threshold,
91
  num_guidance_steps,
92
  seed,
93
- boxes
94
  ):
95
  subject_token_indices = convert_token_indices(subject_token_indices, nested=True)
96
  if len(boxes) != len(subject_token_indices):
@@ -157,6 +182,37 @@ def clear(batch_size):
157
  return [[], None, None, None]
158
 
159
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
  def main():
161
  css = """
162
  #paper-info a {
@@ -279,9 +335,9 @@ def main():
279
  </div>
280
  """
281
  gr.HTML(description)
282
- batch_size = gr.Slider(minimum=1, maximum=5, step=1, value=1, label="Number of samples (currently limited to one sample)")
283
- init_step_size = gr.Slider(minimum=0, maximum=50, step=0.5, value=25, label="Initial step size")
284
- final_step_size = gr.Slider(minimum=0, maximum=20, step=0.5, value=10, label="Final step size")
285
  num_clusters_per_subject = gr.Slider(minimum=0, maximum=5, step=0.5, value=3, label="Number of clusters per subject")
286
  cross_loss_scale = gr.Slider(minimum=0, maximum=2, step=0.1, value=1, label="Cross-attention loss scale factor")
287
  self_loss_scale = gr.Slider(minimum=0, maximum=2, step=0.1, value=1, label="Self-attention loss scale factor")
@@ -320,22 +376,41 @@ def main():
320
  queue=True,
321
  )
322
 
323
- #with gr.Column():
324
- # gr.Examples(
325
- # examples=[
326
- # [
327
- # [[0.35, 0.4, 0.65, 0.9], [0, 0.6, 0.3, 0.9], [0.7, 0.55, 1, 0.85]],
328
- # "3D Pixar animation of a cute unicorn and a pink hedgehog and a nerdy owl traveling in a magical forest",
329
- # "7,8,17;11,12,17;15,16,17",
330
- # "5,6,9,10,13,14,18,19",
331
- # 286,
332
- # ],
333
- # ],
334
- # inputs=[boxes, prompt, subject_token_indices, filter_token_indices, seed],
335
- # outputs=None,
336
- # fn=None,
337
- # cache_examples=False,
338
- # )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
339
  description = """<p> The source code of this demo is based on the <a href="https://huggingface.co/spaces/gligen/demo/tree/main">GLIGEN demo</a>.</p>"""
340
  gr.HTML(description)
341
 
 
19
  COLORS = ["red", "blue", "green", "orange", "purple", "turquoise", "olive"]
20
 
21
 
22
# Prompts of the bundled examples. Each string is also the lookup key into
# EXAMPLE_BOXES (generate_example indexes the dict by the prompt text), so the
# prompt strings listed in the gr.Examples table must match these exactly.
PROMPT1 = "3D Pixar animation of a cute unicorn and a pink hedgehog and a nerdy owl traveling in a magical forest"
PROMPT2 = "science fiction movie poster with an astronaut and a robot and a green alien and a spaceship"
PROMPT3 = "a golden retriever and a german shepherd and a boston terrier and an english bulldog and a border collie in a pool"

# Layout boxes for each example prompt: one four-number box per subject.
# NOTE(review): values look like normalized [x_min, y_min, x_max, y_max]
# coordinates in [0, 1] -- confirm against draw_boxes.
EXAMPLE_BOXES = {
    PROMPT1: [
        [0.35, 0.4, 0.65, 0.9],
        [0, 0.6, 0.3, 0.9],
        [0.7, 0.55, 1, 0.85],
    ],
    PROMPT2: [
        [0.4, 0.45, 0.6, 0.95],
        [0.2, 0.3, 0.4, 0.85],
        [0.6, 0.3, 0.8, 0.85],
        [0.1, 0, 0.9, 0.3],
    ],
    PROMPT3: [
        [0, 0.5, 0.2, 0.8],
        [0.2, 0.2, 0.4, 0.5],
        [0.4, 0.5, 0.6, 0.8],
        [0.6, 0.2, 0.8, 0.5],
        [0.8, 0.5, 1, 0.8],
    ],
}
45
+
46
+
47
  def inference(
48
  boxes,
49
  prompts,
 
115
  loss_threshold,
116
  num_guidance_steps,
117
  seed,
118
+ boxes,
119
  ):
120
  subject_token_indices = convert_token_indices(subject_token_indices, nested=True)
121
  if len(boxes) != len(subject_token_indices):
 
182
  return [[], None, None, None]
183
 
184
 
185
def generate_example(
    prompt,
    subject_token_indices,
    filter_token_indices,
    num_tokens,
    init_step_size,
    final_step_size,
    num_clusters_per_subject,
    cross_loss_scale,
    self_loss_scale,
    classifier_free_guidance_scale,
    batch_size,
    num_iterations,
    loss_threshold,
    num_guidance_steps,
    seed,
):
    """Run one of the bundled examples using its predefined layout boxes.

    The boxes are looked up in ``EXAMPLE_BOXES`` by the prompt text, rendered
    with ``draw_boxes`` (once per box for the sketchpad layers and once for
    the combined layout), and then passed to ``generate`` together with all
    the other arguments, which are forwarded unchanged.

    Args:
        prompt: Example prompt; must be a key of ``EXAMPLE_BOXES``.
        subject_token_indices, filter_token_indices, num_tokens,
        init_step_size, final_step_size, num_clusters_per_subject,
        cross_loss_scale, self_loss_scale, classifier_free_guidance_scale,
        batch_size, num_iterations, loss_threshold, num_guidance_steps,
        seed: Passed through to ``generate`` in this order.

    Returns:
        A 4-tuple ``(boxes, sketchpad, layout_image, out_images)`` in the
        order expected by the ``outputs`` list of the ``gr.Examples`` call
        that uses this function.

    Raises:
        KeyError: If ``prompt`` is not one of the keys of ``EXAMPLE_BOXES``.
    """
    boxes = EXAMPLE_BOXES[prompt]

    # One rendered image per box, so each box becomes its own sketchpad layer.
    layers = [draw_boxes([box]) for box in boxes]
    sketchpad = {'layers': layers}

    # Single image with all boxes drawn together. The original bound this to
    # `layout_images` but returned `layout_image`, raising NameError on every
    # call; the name is now consistent.
    layout_image = draw_boxes(boxes)

    out_images = generate(
        prompt, subject_token_indices, filter_token_indices, num_tokens, init_step_size,
        final_step_size, num_clusters_per_subject, cross_loss_scale, self_loss_scale,
        classifier_free_guidance_scale, batch_size, num_iterations, loss_threshold,
        num_guidance_steps, seed, boxes,
    )

    return boxes, sketchpad, layout_image, out_images
214
+
215
+
216
  def main():
217
  css = """
218
  #paper-info a {
 
335
  </div>
336
  """
337
  gr.HTML(description)
338
+ batch_size = gr.Slider(minimum=1, maximum=5, step=1, value=1, label="Number of samples (limited to one sample on current space)")
339
+ init_step_size = gr.Slider(minimum=0, maximum=50, step=0.5, value=18, label="Initial step size")
340
+ final_step_size = gr.Slider(minimum=0, maximum=20, step=0.5, value=5, label="Final step size")
341
  num_clusters_per_subject = gr.Slider(minimum=0, maximum=5, step=0.5, value=3, label="Number of clusters per subject")
342
  cross_loss_scale = gr.Slider(minimum=0, maximum=2, step=0.1, value=1, label="Cross-attention loss scale factor")
343
  self_loss_scale = gr.Slider(minimum=0, maximum=2, step=0.1, value=1, label="Self-attention loss scale factor")
 
376
  queue=True,
377
  )
378
 
379
+ with gr.Column():
380
+ gr.Examples(
381
+ examples=[
382
+ [
383
+ "3D Pixar animation of a cute unicorn and a pink hedgehog and a nerdy owl traveling in a magical forest",
384
+ "7,8,17;11,12,17;15,16,17", "5,6,9,10,13,14,18,19", "21",
385
+ 25, 10, 3, 1, 1,
386
+ 7.5, 1, 5, 0.2, 15,
387
+ 286,
388
+ ],
389
+ [
390
+ "science fiction movie poster with an astronaut and a robot and a green alien and a spaceship",
391
+ "7;10;13,14;17", "5,6,8,9,11,12,15,16", "17",
392
+ 18, 5, 3, 1, 1,
393
+ 7.5, 1, 5, 0.2, 15,
394
+ 216,
395
+ ],
396
+ [
397
+ "a golden retriever and a german shepherd and a boston terrier and an english bulldog and a border collie in a pool",
398
+ "2,3;6,7;10,11;14,15;18,19", "1,4,5,8,9,12,13,16,17,20,21", "22",
399
+ 18, 5, 3, 1, 1,
400
+ 7.5, 1, 5, 0.2, 15,
401
+ 156,
402
+ ],
403
+ ],
404
+ fn=generate_example,
405
+ inputs=[
406
+ prompt, subject_token_indices, filter_token_indices, num_tokens,
407
+ init_step_size, final_step_size, num_clusters_per_subject, cross_loss_scale, self_loss_scale,
408
+ classifier_free_guidance_scale, batch_size, num_iterations, loss_threshold, num_guidance_steps,
409
+ seed,
410
+ ],
411
+ outputs=[boxes, sketchpad, layout_image, out_images],
412
+ cache_examples=True,
413
+ )
414
  description = """<p> The source code of this demo is based on the <a href="https://huggingface.co/spaces/gligen/demo/tree/main">GLIGEN demo</a>.</p>"""
415
  gr.HTML(description)
416