omer11a commited on
Commit
b322d0b
·
1 Parent(s): 4c60292

Fix examples

Browse files
Files changed (1) hide show
  1. app.py +24 -47
app.py CHANGED
@@ -18,7 +18,7 @@ MIN_SIZE = 0.01
18
  WHITE = 255
19
  COLORS = ["red", "blue", "green", "orange", "purple", "turquoise", "olive"]
20
 
21
- PROMPT1 = "3D Pixar animation of a cute unicorn and a pink hedgehog and a nerdy owl traveling in a magical forest"
22
  PROMPT2 = "science fiction movie poster with an astronaut and a robot and a green alien and a spaceship"
23
  PROMPT3 = "a golden retriever and a german shepherd and a boston terrier and an english bulldog and a border collie in a pool"
24
  EXAMPLE_BOXES = {
@@ -146,7 +146,7 @@ def generate(
146
 
147
  filter_token_indices = convert_token_indices(filter_token_indices) if len(filter_token_indices.strip()) > 0 else None
148
  num_tokens = int(num_tokens) if len(num_tokens.strip()) > 0 else None
149
- prompts = [prompt.strip('.').strip(',').strip()] * batch_size
150
 
151
  images = inference(
152
  boxes, prompts, subject_token_indices, filter_token_indices, num_tokens, init_step_size,
@@ -158,14 +158,14 @@ def generate(
158
 
159
  def convert_token_indices(token_indices, nested=False):
160
  if nested:
161
- return [convert_token_indices(indices, nested=False) for indices in token_indices.split(';')]
162
 
163
- return [int(index.strip()) for index in token_indices.split(',') if len(index.strip()) > 0]
164
 
165
 
166
  def draw(sketchpad):
167
  boxes = []
168
- for i, layer in enumerate(sketchpad['layers']):
169
  non_zeros = layer.nonzero()
170
  x1 = x2 = y1 = y2 = 0
171
  if len(non_zeros[0]) > 0:
@@ -177,7 +177,7 @@ def draw(sketchpad):
177
  y2 = y1y2.max()
178
 
179
  if (x2 - x1 < MIN_SIZE) or (y2 - y1 < MIN_SIZE):
180
- raise gr.Error(f'Box in layer {i} is too small')
181
 
182
  boxes.append((x1, y1, x2, y2))
183
 
@@ -185,15 +185,16 @@ def draw(sketchpad):
185
  return [boxes, layout_image]
186
 
187
 
188
- def draw_boxes(boxes):
189
  if len(boxes) == 0:
190
  return None
191
 
192
  boxes = np.array(boxes) * RESOLUTION
193
- image = Image.new('RGB', (RESOLUTION, RESOLUTION), (WHITE, WHITE, WHITE))
194
  drawing = ImageDraw.Draw(image)
195
  for i, box in enumerate(boxes.astype(int).tolist()):
196
- drawing.rectangle(box, outline=COLORS[i % len(COLORS)], width=4)
 
197
 
198
  return image
199
 
@@ -202,35 +203,11 @@ def clear(batch_size):
202
  return [[], None, None, None]
203
 
204
 
205
- def generate_example(
206
- prompt,
207
- subject_token_indices,
208
- filter_token_indices,
209
- num_tokens,
210
- init_step_size,
211
- final_step_size,
212
- num_clusters_per_subject,
213
- cross_loss_scale,
214
- self_loss_scale,
215
- classifier_free_guidance_scale,
216
- batch_size,
217
- num_iterations,
218
- loss_threshold,
219
- num_guidance_steps,
220
- seed,
221
- ):
222
- layers = []
223
  boxes = EXAMPLE_BOXES[prompt]
224
- for box in boxes:
225
- layers.append(draw_boxes([box]))
226
-
227
- sketchpad = {'layers': layers}
228
- layout_images = draw_boxes(boxes)
229
- out_images = generate(prompt, subject_token_indices, filter_token_indices, num_tokens, init_step_size,
230
- final_step_size, num_clusters_per_subject, cross_loss_scale, self_loss_scale, classifier_free_guidance_scale,
231
- batch_size, num_iterations, loss_threshold, num_guidance_steps, seed, boxes)
232
-
233
- return boxes, sketchpad, layout_image, out_images
234
 
235
 
236
  def main():
@@ -274,7 +251,7 @@ def main():
274
  }
275
  """
276
 
277
- nltk.download('averaged_perceptron_tagger')
278
 
279
  with gr.Blocks(
280
  css=css,
@@ -301,13 +278,13 @@ def main():
301
  )
302
 
303
  with gr.Row():
304
- sketchpad = gr.Sketchpad(label="Sketch Pad", width=RESOLUTION, height=RESOLUTION)
305
  layout_image = gr.Image(type="pil", label="Bounding Boxes", interactive=False, width=RESOLUTION, height=RESOLUTION, scale=1)
306
 
307
  with gr.Row():
308
- clear_button = gr.Button(value='Clear')
309
- generate_layout_button = gr.Button(value='Generate layout')
310
- generate_image_button = gr.Button(value='Generate image')
311
 
312
  with gr.Row():
313
  out_images = gr.Gallery(type="pil", label="Generated Images", interactive=False)
@@ -392,29 +369,29 @@ def main():
392
  gr.Examples(
393
  examples=[
394
  [
395
- "3D Pixar animation of a cute unicorn and a pink hedgehog and a nerdy owl traveling in a magical forest",
396
  "7,8,17;11,12,17;15,16,17", "5,6,9,10,13,14,18,19", "21",
397
  25, 10, 3, 1, 1,
398
  7.5, 1, 5, 0.2, 15,
399
  286,
400
  ],
401
  [
402
- "science fiction movie poster with an astronaut and a robot and a green alien and a spaceship",
403
  "7;10;13,14;17", "5,6,8,9,11,12,15,16", "17",
404
  18, 5, 3, 1, 1,
405
  7.5, 1, 5, 0.2, 15,
406
  216,
407
  ],
408
  [
409
- "a golden retriever and a german shepherd and a boston terrier and an english bulldog and a border collie in a pool",
410
  "2,3;6,7;10,11;14,15;18,19", "1,4,5,8,9,12,13,16,17,20,21", "22",
411
  18, 5, 3, 1, 1,
412
  7.5, 1, 5, 0.2, 15,
413
  156,
414
  ],
415
  ],
416
- fn=generate_example,
417
  inputs=[
 
418
  prompt, subject_token_indices, filter_token_indices, num_tokens,
419
  init_step_size, final_step_size, num_clusters_per_subject, cross_loss_scale, self_loss_scale,
420
  classifier_free_guidance_scale, batch_size, num_iterations, loss_threshold, num_guidance_steps,
@@ -427,5 +404,5 @@ def main():
427
 
428
  demo.launch(show_api=False, show_error=True)
429
 
430
- if __name__ == '__main__':
431
  main()
 
18
  WHITE = 255
19
  COLORS = ["red", "blue", "green", "orange", "purple", "turquoise", "olive"]
20
 
21
+ PROMPT1 = "3D Pixar animation of a cute unicorn and a pink hedgehog and a nerdy owl traveling in a magical forest"
22
  PROMPT2 = "science fiction movie poster with an astronaut and a robot and a green alien and a spaceship"
23
  PROMPT3 = "a golden retriever and a german shepherd and a boston terrier and an english bulldog and a border collie in a pool"
24
  EXAMPLE_BOXES = {
 
146
 
147
  filter_token_indices = convert_token_indices(filter_token_indices) if len(filter_token_indices.strip()) > 0 else None
148
  num_tokens = int(num_tokens) if len(num_tokens.strip()) > 0 else None
149
+ prompts = [prompt.strip(".").strip(",").strip()] * batch_size
150
 
151
  images = inference(
152
  boxes, prompts, subject_token_indices, filter_token_indices, num_tokens, init_step_size,
 
158
 
159
  def convert_token_indices(token_indices, nested=False):
160
  if nested:
161
+ return [convert_token_indices(indices, nested=False) for indices in token_indices.split(";")]
162
 
163
+ return [int(index.strip()) for index in token_indices.split(",") if len(index.strip()) > 0]
164
 
165
 
166
  def draw(sketchpad):
167
  boxes = []
168
+ for i, layer in enumerate(sketchpad["layers"]):
169
  non_zeros = layer.nonzero()
170
  x1 = x2 = y1 = y2 = 0
171
  if len(non_zeros[0]) > 0:
 
177
  y2 = y1y2.max()
178
 
179
  if (x2 - x1 < MIN_SIZE) or (y2 - y1 < MIN_SIZE):
180
+ raise gr.Error(f"Box in layer {i} is too small")
181
 
182
  boxes.append((x1, y1, x2, y2))
183
 
 
185
  return [boxes, layout_image]
186
 
187
 
188
+ def draw_boxes(boxes, is_sketch=False):
189
  if len(boxes) == 0:
190
  return None
191
 
192
  boxes = np.array(boxes) * RESOLUTION
193
+ image = Image.new("RGB", (RESOLUTION, RESOLUTION), (WHITE, WHITE, WHITE))
194
  drawing = ImageDraw.Draw(image)
195
  for i, box in enumerate(boxes.astype(int).tolist()):
196
+ color = "black" if is_sketch else COLORS[i % len(COLORS)]
197
+ drawing.rectangle(box, outline=color, width=4)
198
 
199
  return image
200
 
 
203
  return [[], None, None, None]
204
 
205
 
206
+ def make_example_inputs(prompt):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
  boxes = EXAMPLE_BOXES[prompt]
208
+ sketchpad = draw_boxes(boxes, is_sketch=True)
209
+ layout_image = draw_boxes(boxes)
210
+ return sketchpad, layout_image, prompt
 
 
 
 
 
 
 
211
 
212
 
213
  def main():
 
251
  }
252
  """
253
 
254
+ nltk.download("averaged_perceptron_tagger")
255
 
256
  with gr.Blocks(
257
  css=css,
 
278
  )
279
 
280
  with gr.Row():
281
+ sketchpad = gr.Sketchpad(label="Sketch Pad (draw each bounding box in a different layer)", width=RESOLUTION, height=RESOLUTION)
282
  layout_image = gr.Image(type="pil", label="Bounding Boxes", interactive=False, width=RESOLUTION, height=RESOLUTION, scale=1)
283
 
284
  with gr.Row():
285
+ clear_button = gr.Button(value="Clear")
286
+ generate_layout_button = gr.Button(value="Generate layout")
287
+ generate_image_button = gr.Button(value="Generate image")
288
 
289
  with gr.Row():
290
  out_images = gr.Gallery(type="pil", label="Generated Images", interactive=False)
 
369
  gr.Examples(
370
  examples=[
371
  [
372
+ *make_example_inputs(PROMPT1),
373
  "7,8,17;11,12,17;15,16,17", "5,6,9,10,13,14,18,19", "21",
374
  25, 10, 3, 1, 1,
375
  7.5, 1, 5, 0.2, 15,
376
  286,
377
  ],
378
  [
379
+ *make_example_inputs(PROMPT2),
380
  "7;10;13,14;17", "5,6,8,9,11,12,15,16", "17",
381
  18, 5, 3, 1, 1,
382
  7.5, 1, 5, 0.2, 15,
383
  216,
384
  ],
385
  [
386
+ *make_example_inputs(PROMPT3),
387
  "2,3;6,7;10,11;14,15;18,19", "1,4,5,8,9,12,13,16,17,20,21", "22",
388
  18, 5, 3, 1, 1,
389
  7.5, 1, 5, 0.2, 15,
390
  156,
391
  ],
392
  ],
 
393
  inputs=[
394
+ sketchpad, layout_image,
395
  prompt, subject_token_indices, filter_token_indices, num_tokens,
396
  init_step_size, final_step_size, num_clusters_per_subject, cross_loss_scale, self_loss_scale,
397
  classifier_free_guidance_scale, batch_size, num_iterations, loss_threshold, num_guidance_steps,
 
404
 
405
  demo.launch(show_api=False, show_error=True)
406
 
407
+ if __name__ == "__main__":
408
  main()