omer11a committed
Commit 72df8c8 · 1 Parent(s): 9d8ec57

Modified examples
Files changed (2):
  1. app.py +79 -7
  2. bounded_attention.py +18 -16
app.py CHANGED
@@ -19,6 +19,30 @@ MIN_SIZE = 0.01
 WHITE = 255
 COLORS = ["red", "blue", "green", "orange", "purple", "turquoise", "olive"]
 
+PROMPT1 = "3D Pixar animation of a cute unicorn and a pink hedgehog and a nerdy owl traveling in a magical forest"
+PROMPT2 = "science fiction movie poster with an astronaut and a robot and a green alien and a spaceship"
+PROMPT3 = "a golden retriever and a german shepherd and a boston terrier and an english bulldog and a border collie in a pool"
+EXAMPLE_BOXES = {
+    PROMPT1 : [
+        [0.35, 0.4, 0.65, 0.9],
+        [0, 0.6, 0.3, 0.9],
+        [0.7, 0.55, 1, 0.85]
+    ],
+    PROMPT2: [
+        [0.4, 0.45, 0.6, 0.95],
+        [0.2, 0.3, 0.4, 0.85],
+        [0.6, 0.3, 0.8, 0.85],
+        [0.1, 0, 0.9, 0.3]
+    ],
+    PROMPT3: [
+        [0, 0.5, 0.2, 0.8],
+        [0.2, 0.2, 0.4, 0.5],
+        [0.4, 0.5, 0.6, 0.8],
+        [0.6, 0.2, 0.8, 0.5],
+        [0.8, 0.5, 1, 0.8]
+    ],
+}
+
 CSS = """
 #paper-info a {
     color:#008AD7;
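Note: each box in EXAMPLE_BOXES appears to be [left, top, right, bottom] in coordinates normalized to [0, 1] — a reading inferred from the value ranges, not stated in the commit. A minimal sanity check under that assumption (MIN_SIZE = 0.01 comes from the hunk header's context line):

# Assumed box format: [left, top, right, bottom], all in [0, 1].
MIN_SIZE = 0.01  # from app.py's context; smallest box side allowed

def validate_box(box, min_size=MIN_SIZE):
    left, top, right, bottom = box
    if not (0 <= left < right <= 1 and 0 <= top < bottom <= 1):
        raise ValueError(f"box out of range: {box}")
    if (right - left) < min_size or (bottom - top) < min_size:
        raise ValueError(f"box side smaller than {min_size}: {box}")

validate_box([0.35, 0.4, 0.65, 0.9])  # the unicorn box from EXAMPLE_BOXES[PROMPT1]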
@@ -88,8 +112,11 @@ ADVANCED_OPTION_DESCRIPTION = """
 <div class="tooltip">Final step size &#9432
 <span class="tooltiptext">The final step size of the linear step size scheduler when performing guidance.</span>
 </div>
+<div class="tooltip">First refinement step &#9432
+<span class="tooltiptext">The timestep from which subject mask refinement is performed.</span>
+</div>
 <div class="tooltip">Number of self-attention clusters per subject &#9432
-<span class="tooltiptext">Determines the number of clusters when clustering the self-attention maps (#clusters = #subject x #clusters_per_subject). Changing this value might improve semantics (adherence to the prompt), especially when the subjects exceed their bounding boxes.</span>
+<span class="tooltiptext">The number of clusters computed when clustering the self-attention maps (#clusters = #subject x #clusters_per_subject). Changing this value might improve semantics (adherence to the prompt), especially when the subjects exceed their bounding boxes.</span>
 </div>
 <div class="tooltip">Cross-attention loss scale factor &#9432
 <span class="tooltiptext">The scale factor of the cross-attention loss term. Increasing it will improve semantic control (adherence to the prompt), but may reduce image quality.</span>
@@ -120,6 +147,7 @@ def inference(
     num_tokens,
     init_step_size,
     final_step_size,
+    first_refinement_step,
     num_clusters_per_subject,
     cross_loss_scale,
     self_loss_scale,
@@ -158,6 +186,7 @@ def inference(
         start_step_size=init_step_size,
         end_step_size=final_step_size,
         loss_stopping_value=loss_threshold,
+        min_clustering_step=first_refinement_step,
         num_clusters_per_box=num_clusters_per_subject,
         max_resolution=32,
     )
@@ -174,6 +203,7 @@ def generate(
     num_tokens,
     init_step_size,
     final_step_size,
+    first_refinement_step,
     num_clusters_per_subject,
     cross_loss_scale,
     self_loss_scale,
@@ -185,6 +215,7 @@ def generate(
     seed,
     boxes,
 ):
+    print('boxes in generate', boxes)
     subject_token_indices = convert_token_indices(subject_token_indices, nested=True)
     if len(boxes) != len(subject_token_indices):
         raise gr.Error("""
@@ -198,8 +229,8 @@ def generate(
 
     images = inference(
         boxes, prompts, subject_token_indices, filter_token_indices, num_tokens, init_step_size,
-        final_step_size, num_clusters_per_subject, cross_loss_scale, self_loss_scale, classifier_free_guidance_scale,
-        num_iterations, loss_threshold, num_guidance_steps, seed)
+        final_step_size, first_refinement_step, num_clusters_per_subject, cross_loss_scale, self_loss_scale,
+        classifier_free_guidance_scale, num_iterations, loss_threshold, num_guidance_steps, seed)
 
     return images
 
@@ -251,6 +282,17 @@ def clear(batch_size):
     return [[], None, None, None]
 
 
+def build_example_layout(prompt, *args):
+    boxes = EXAMPLE_BOXES[prompt]
+
+    composite = draw_boxes(boxes, is_sketch=True)
+    sketchpad = {"background": None, "layers": [], "composite": composite}
+
+    layout_image = draw_boxes(boxes)
+
+    return boxes, sketchpad, layout_image
+
+
 def main():
     nltk.download("averaged_perceptron_tagger")
 
@@ -303,6 +345,7 @@ def main():
             num_guidance_steps = gr.Slider(minimum=5, maximum=20, step=1, value=8, label="Number of timesteps to perform guidance")
             init_step_size = gr.Slider(minimum=0, maximum=50, step=0.5, value=30, label="Initial step size")
             final_step_size = gr.Slider(minimum=0, maximum=20, step=0.5, value=15, label="Final step size")
+            first_refinement_step = gr.Slider(minimum=0, maximum=50, step=1, value=15, label="The timestep from which to start refining the subject masks")
            num_clusters_per_subject = gr.Slider(minimum=0, maximum=5, step=0.5, value=3, label="Number of clusters per subject")
             cross_loss_scale = gr.Slider(minimum=0, maximum=2, step=0.1, value=1, label="Cross-attention loss scale factor")
             self_loss_scale = gr.Slider(minimum=0, maximum=2, step=0.1, value=1, label="Self-attention loss scale factor")
@@ -331,7 +374,7 @@ def main():
             fn=generate,
             inputs=[
                 prompt, subject_token_indices, filter_token_indices, num_tokens,
-                init_step_size, final_step_size, num_clusters_per_subject, cross_loss_scale, self_loss_scale,
+                init_step_size, final_step_size, first_refinement_step, num_clusters_per_subject, cross_loss_scale, self_loss_scale,
                 classifier_free_guidance_scale, batch_size, num_iterations, loss_threshold, num_guidance_steps,
                 seed,
                 boxes,
@@ -342,11 +385,40 @@ def main():
 
         with gr.Column():
             gr.Examples(
+                #examples=[
+                #    ["a ginger kitten and a gray puppy in a yard", "2,3;6,7", "1,4,5,8,9", "10"],
+                #    ["a realistic photo of a highway with a semi trailer and a concrete mixer and a helicopter", "9,10;13,14;17", "1,4,5,7,8,11,12,15,16", "17"],
+                #],
+                #inputs=[prompt, subject_token_indices, filter_token_indices, num_tokens],
                 examples=[
-                    ["a ginger kitten and a gray puppy in a yard", "2,3;6,7", "1,4,5,8,9", "10"],
-                    ["a realistic photo of a highway with a semi trailer and a concrete mixer and a helicopter", "9,10;13,14;17", "1,4,5,7,8,11,12,15,16", "17"],
+                    [
+                        PROMPT1, "7,8,17;11,12,17;15,16,17", "5,6,9,10,13,14,18,19", "21",
+                        25, 18, 3, 1, 1,
+                        7.5, 1, 5, 0.2, 8,
+                        286,
+                    ],
+                    [
+                        PROMPT2, "7;10;13,14;17", "5,6,8,9,11,12,15,16", "17",
+                        18, 12, 3, 1, 1,
+                        7.5, 1, 5, 0.2, 8,
+                        216,
+                    ],
+                    [
+                        PROMPT3, "2,3;6,7;10,11;14,15;18,19", "1,4,5,8,9,12,13,16,17,20,21", "22",
+                        18, 12, 3, 1, 1,
+                        7.5, 1, 5, 0.2, 8,
+                        156,
+                    ],
                 ],
-                inputs=[prompt, subject_token_indices, filter_token_indices, num_tokens],
+                fn=build_example_layout,
+                inputs=[
+                    prompt, subject_token_indices, filter_token_indices, num_tokens,
+                    init_step_size, final_step_size, num_clusters_per_subject, cross_loss_scale, self_loss_scale,
+                    classifier_free_guidance_scale, batch_size, num_iterations, loss_threshold, num_guidance_steps,
+                    seed,
+                ],
+                outputs=[boxes, sketchpad, layout_image],
+                run_on_click=True,
             )
 
     gr.HTML(FOOTNOTE)
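Note: the example rows encode token indices as strings, with subjects separated by ';' and token indices within a subject by ',' (e.g. "2,3;6,7" marks two subjects spanning tokens 2-3 and 6-7). A hypothetical reconstruction of convert_token_indices from its call sites — the real implementation lives elsewhere in app.py and may differ:

def convert_token_indices(text, nested=False):
    # "2,3;6,7" with nested=True -> [[2, 3], [6, 7]]
    # "1,4,5,8,9" with nested=False -> [1, 4, 5, 8, 9]
    if nested:
        return [[int(i) for i in group.split(",") if i.strip()]
                for group in text.split(";") if group.strip()]
    return [int(i) for i in text.split(",") if i.strip()]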
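Note: the rewritten gr.Examples block depends on the fn/outputs/run_on_click=True wiring: clicking an example row fills the listed inputs and also runs fn, writing its return values to outputs — which is how build_example_layout repopulates the boxes state, the sketchpad, and the layout preview. A minimal self-contained sketch of the same pattern (hypothetical components; assumes a Gradio 4.x-era API where gr.Examples accepts run_on_click):

import gradio as gr

PREVIEWS = {"a cat and a dog": "two subjects", "a boat": "one subject"}  # stand-in data

def build_preview(prompt):
    # stand-in for build_example_layout: derive extra UI state from the clicked row
    return PREVIEWS.get(prompt, "unknown prompt")

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    preview = gr.Textbox(label="Derived layout info")
    gr.Examples(
        examples=[["a cat and a dog"], ["a boat"]],
        inputs=[prompt],
        fn=build_preview,
        outputs=[preview],
        run_on_click=True,  # run fn on click instead of only filling inputs
    )

if __name__ == "__main__":
    demo.launch()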
bounded_attention.py CHANGED
@@ -38,6 +38,7 @@ class BoundedAttention(injection_utils.AttentionBase):
         start_step_size=30,
         end_step_size=10,
         loss_stopping_value=0.2,
+        min_clustering_step=15,
         cross_mask_threshold=0.2,
         self_mask_threshold=0.2,
         delta_refine_mask_steps=5,
@@ -73,6 +74,7 @@ class BoundedAttention(injection_utils.AttentionBase):
         self.start_step_size = start_step_size
         self.step_size_coef = (end_step_size - start_step_size) / max_guidance_iter
         self.loss_stopping_value = loss_stopping_value
+        self.min_clustering_step = min_clustering_step
         self.cross_mask_threshold = cross_mask_threshold
         self.self_mask_threshold = self_mask_threshold
 
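Note: the constructor's step_size_coef = (end_step_size - start_step_size) / max_guidance_iter implies a linear step size schedule — the guidance step size starts at start_step_size and decays linearly to end_step_size over max_guidance_iter guidance steps, matching the "Initial step size" and "Final step size" sliders in app.py. A sketch of that schedule (the per-step formula is an assumed reading of the diff):

def guidance_step_size(cur_iter, start_step_size=30.0, end_step_size=15.0, max_guidance_iter=8):
    # linear interpolation from start_step_size (iteration 0) to end_step_size (last iteration)
    coef = (end_step_size - start_step_size) / max_guidance_iter
    return start_step_size + coef * cur_iter

print(guidance_step_size(0))  # 30.0
print(guidance_step_size(8))  # 15.0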
@@ -367,7 +369,7 @@ class BoundedAttention(injection_utils.AttentionBase):
 
     def _obtain_masks(self, resolution, return_boxes=False, return_existing=False, batch_size=None, device=None):
         return_boxes = return_boxes or (return_existing and self.self_masks is None)
-        if return_boxes or self.cur_step < self.max_guidance_iter:
+        if return_boxes or self.cur_step < self.min_clustering_step:
             masks = self._convert_boxes_to_masks(resolution, device=device).unsqueeze(0)
             if batch_size is not None:
                 masks = masks.expand(batch_size, *masks.shape[1:])
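Note: this hunk is the behavioral core of the commit: previously the coarse box masks were used while cur_step < max_guidance_iter; now the cutoff is the new, independently tunable min_clustering_step ("first refinement step" in the UI). A runnable sketch of that gating, with boxes_to_masks as a hypothetical stand-in for _convert_boxes_to_masks and assuming normalized [left, top, right, bottom] boxes:

import torch

def boxes_to_masks(boxes, resolution):
    # rasterize each normalized box into a boolean (resolution x resolution) mask
    masks = torch.zeros(len(boxes), resolution, resolution, dtype=torch.bool)
    for i, (left, top, right, bottom) in enumerate(boxes):
        masks[i,
              int(top * resolution):int(bottom * resolution),
              int(left * resolution):int(right * resolution)] = True
    return masks

def choose_masks(cur_step, min_clustering_step, boxes, refined_masks, resolution):
    # early steps fall back to coarse box masks; later steps use the
    # self-attention masks refined by clustering
    if cur_step < min_clustering_step:
        return boxes_to_masks(boxes, resolution)
    return refined_masks

masks = choose_masks(cur_step=5, min_clustering_step=15,
                     boxes=[[0.35, 0.4, 0.65, 0.9]], refined_masks=None, resolution=32)
print(masks.shape, int(masks.sum()))  # torch.Size([1, 32, 32]) 144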
@@ -406,21 +408,6 @@ class BoundedAttention(injection_utils.AttentionBase):
         self_masks = F.interpolate(self_masks, resolution, mode='nearest-exact')
         return self_masks.flatten(start_dim=2).bool()
 
-    def _cluster_self_maps(self): # b s n
-        self_maps = self._compute_maps(self.mean_self_map) # b n m
-        if self.pca_rank is not None:
-            dtype = self_maps.dtype
-            _, _, eigen_vectors = torch.pca_lowrank(self_maps.float(), self.pca_rank)
-            self_maps = torch.matmul(self_maps, eigen_vectors.to(dtype=dtype))
-
-        clustering_results = self.clustering(self_maps, centers=self.centers)
-        self.clustering.num_init = 1 # clustering is deterministic after the first time
-        self.centers = clustering_results.centers
-        clusters = clustering_results.labels
-        num_clusters = self.clustering.n_clusters
-        self._save_maps(clusters / num_clusters, f'clusters')
-        return num_clusters, clusters
-
     def _build_self_masks(self):
         c, clusters = self._cluster_self_maps() # b n
         cluster_masks = torch.stack([(clusters == cluster_index) for cluster_index in range(c)], dim=2) # b n c
@@ -444,6 +431,21 @@ class BoundedAttention(injection_utils.AttentionBase):
         self._save_maps(self_masks, 'self_masks')
         return self_masks
 
+    def _cluster_self_maps(self): # b s n
+        self_maps = self._compute_maps(self.mean_self_map) # b n m
+        if self.pca_rank is not None:
+            dtype = self_maps.dtype
+            _, _, eigen_vectors = torch.pca_lowrank(self_maps.float(), self.pca_rank)
+            self_maps = torch.matmul(self_maps, eigen_vectors.to(dtype=dtype))
+
+        clustering_results = self.clustering(self_maps, centers=self.centers)
+        self.clustering.num_init = 1 # clustering is deterministic after the first time
+        self.centers = clustering_results.centers
+        clusters = clustering_results.labels
+        num_clusters = self.clustering.n_clusters
+        self._save_maps(clusters / num_clusters, f'clusters')
+        return num_clusters, clusters
+
     def _obtain_cross_masks(self, resolution, scale=10):
         maps = self._compute_maps(self.mean_cross_map, resolution=resolution) # b n k
         maps = F.sigmoid(scale * (maps - self.cross_mask_threshold))