Modified examples

- app.py (+79 -7)
- bounded_attention.py (+18 -16)
app.py
CHANGED
@@ -19,6 +19,30 @@ MIN_SIZE = 0.01
 WHITE = 255
 COLORS = ["red", "blue", "green", "orange", "purple", "turquoise", "olive"]
 
+PROMPT1 = "3D Pixar animation of a cute unicorn and a pink hedgehog and a nerdy owl traveling in a magical forest"
+PROMPT2 = "science fiction movie poster with an astronaut and a robot and a green alien and a spaceship"
+PROMPT3 = "a golden retriever and a german shepherd and a boston terrier and an english bulldog and a border collie in a pool"
+EXAMPLE_BOXES = {
+    PROMPT1 : [
+        [0.35, 0.4, 0.65, 0.9],
+        [0, 0.6, 0.3, 0.9],
+        [0.7, 0.55, 1, 0.85]
+    ],
+    PROMPT2: [
+        [0.4, 0.45, 0.6, 0.95],
+        [0.2, 0.3, 0.4, 0.85],
+        [0.6, 0.3, 0.8, 0.85],
+        [0.1, 0, 0.9, 0.3]
+    ],
+    PROMPT3: [
+        [0, 0.5, 0.2, 0.8],
+        [0.2, 0.2, 0.4, 0.5],
+        [0.4, 0.5, 0.6, 0.8],
+        [0.6, 0.2, 0.8, 0.5],
+        [0.8, 0.5, 1, 0.8]
+    ],
+}
+
 CSS = """
 #paper-info a {
     color:#008AD7;
@@ -88,8 +112,11 @@ ADVANCED_OPTION_DESCRIPTION = """
 <div class="tooltip">Final step size ⓘ
 <span class="tooltiptext">The final step size of the linear step size scheduler when performing guidance.</span>
 </div>
+<div class="tooltip">First refinement step ⓘ
+<span class="tooltiptext">The timestep from which subject mask refinement is performed.</span>
+</div>
 <div class="tooltip">Number of self-attention clusters per subject ⓘ
-<span class="tooltiptext">
+<span class="tooltiptext">The number of clusters computed when clustering the self-attention maps (#clusters = #subject x #clusters_per_subject). Changing this value might improve semantics (adherence to the prompt), especially when the subjects exceed their bounding boxes.</span>
 </div>
 <div class="tooltip">Cross-attention loss scale factor ⓘ
 <span class="tooltiptext">The scale factor of the cross-attention loss term. Increasing it will improve semantic control (adherence to the prompt), but may reduce image quality.</span>
@@ -120,6 +147,7 @@ def inference(
     num_tokens,
     init_step_size,
     final_step_size,
+    first_refinement_step,
     num_clusters_per_subject,
     cross_loss_scale,
     self_loss_scale,
@@ -158,6 +186,7 @@ def inference(
         start_step_size=init_step_size,
         end_step_size=final_step_size,
         loss_stopping_value=loss_threshold,
+        min_clustering_step=first_refinement_step,
         num_clusters_per_box=num_clusters_per_subject,
         max_resolution=32,
     )
@@ -174,6 +203,7 @@ def generate(
     num_tokens,
     init_step_size,
     final_step_size,
+    first_refinement_step,
     num_clusters_per_subject,
     cross_loss_scale,
     self_loss_scale,
@@ -185,6 +215,7 @@ def generate(
     seed,
     boxes,
 ):
+    print('boxes in generate', boxes)
     subject_token_indices = convert_token_indices(subject_token_indices, nested=True)
     if len(boxes) != len(subject_token_indices):
         raise gr.Error("""
@@ -198,8 +229,8 @@ def generate(
 
     images = inference(
         boxes, prompts, subject_token_indices, filter_token_indices, num_tokens, init_step_size,
-        final_step_size, num_clusters_per_subject, cross_loss_scale, self_loss_scale,
-        num_iterations, loss_threshold, num_guidance_steps, seed)
+        final_step_size, first_refinement_step, num_clusters_per_subject, cross_loss_scale, self_loss_scale,
+        classifier_free_guidance_scale, num_iterations, loss_threshold, num_guidance_steps, seed)
 
     return images
 
@@ -251,6 +282,17 @@ def clear(batch_size):
     return [[], None, None, None]
 
 
+def build_example_layout(prompt, *args):
+    boxes = EXAMPLE_BOXES[prompt]
+
+    composite = draw_boxes(boxes, is_sketch=True)
+    sketchpad = {"background": None, "layers": [], "composite": composite}
+
+    layout_image = draw_boxes(boxes)
+
+    return boxes, sketchpad, layout_image
+
+
 def main():
     nltk.download("averaged_perceptron_tagger")
 
@@ -303,6 +345,7 @@ def main():
                 num_guidance_steps = gr.Slider(minimum=5, maximum=20, step=1, value=8, label="Number of timesteps to perform guidance")
                 init_step_size = gr.Slider(minimum=0, maximum=50, step=0.5, value=30, label="Initial step size")
                 final_step_size = gr.Slider(minimum=0, maximum=20, step=0.5, value=15, label="Final step size")
+                first_refinement_step = gr.Slider(minimum=0, maximum=50, step=1, value=15, label="The timestep from which to start refining the subject masks")
                 num_clusters_per_subject = gr.Slider(minimum=0, maximum=5, step=0.5, value=3, label="Number of clusters per subject")
                 cross_loss_scale = gr.Slider(minimum=0, maximum=2, step=0.1, value=1, label="Cross-attention loss scale factor")
                 self_loss_scale = gr.Slider(minimum=0, maximum=2, step=0.1, value=1, label="Self-attention loss scale factor")
@@ -331,7 +374,7 @@ def main():
             fn=generate,
             inputs=[
                 prompt, subject_token_indices, filter_token_indices, num_tokens,
-                init_step_size, final_step_size, num_clusters_per_subject, cross_loss_scale, self_loss_scale,
+                init_step_size, final_step_size, first_refinement_step, num_clusters_per_subject, cross_loss_scale, self_loss_scale,
                 classifier_free_guidance_scale, batch_size, num_iterations, loss_threshold, num_guidance_steps,
                 seed,
                 boxes,
@@ -342,11 +385,40 @@ def main():
 
         with gr.Column():
             gr.Examples(
+                #examples=[
+                #    ["a ginger kitten and a gray puppy in a yard", "2,3;6,7", "1,4,5,8,9", "10"],
+                #    ["a realistic photo of a highway with a semi trailer and a concrete mixer and a helicopter", "9,10;13,14;17", "1,4,5,7,8,11,12,15,16", "17"],
+                #],
+                #inputs=[prompt, subject_token_indices, filter_token_indices, num_tokens],
                 examples=[
-                    ["a ginger kitten and a gray puppy in a yard", "2,3;6,7", "1,4,5,8,9", "10"],
-                    ["a realistic photo of a highway with a semi trailer and a concrete mixer and a helicopter", "9,10;13,14;17", "1,4,5,7,8,11,12,15,16", "17"],
+                    [
+                        PROMPT1, "7,8,17;11,12,17;15,16,17", "5,6,9,10,13,14,18,19", "21",
+                        25, 18, 3, 1, 1,
+                        7.5, 1, 5, 0.2, 8,
+                        286,
+                    ],
+                    [
+                        PROMPT2, "7;10;13,14;17", "5,6,8,9,11,12,15,16", "17",
+                        18, 12, 3, 1, 1,
+                        7.5, 1, 5, 0.2, 8,
+                        216,
+                    ],
+                    [
+                        PROMPT3, "2,3;6,7;10,11;14,15;18,19", "1,4,5,8,9,12,13,16,17,20,21", "22",
+                        18, 12, 3, 1, 1,
+                        7.5, 1, 5, 0.2, 8,
+                        156,
+                    ],
+                ],
+                fn=build_example_layout,
+                inputs=[
+                    prompt, subject_token_indices, filter_token_indices, num_tokens,
+                    init_step_size, final_step_size, num_clusters_per_subject, cross_loss_scale, self_loss_scale,
+                    classifier_free_guidance_scale, batch_size, num_iterations, loss_threshold, num_guidance_steps,
+                    seed,
                 ],
-                inputs=[prompt, subject_token_indices, filter_token_indices, num_tokens],
+                outputs=[boxes, sketchpad, layout_image],
+                run_on_click=True,
             )
 
         gr.HTML(FOOTNOTE)
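
The core of this commit is the gr.Examples rewiring: example rows alone can only prefill input components, so the commit adds build_example_layout and run_on_click=True to also regenerate the boxes state, the sketchpad composite, and the layout image when a row is clicked. Below is a minimal, self-contained sketch of that Gradio pattern; the components and callback are illustrative stand-ins, not the Space's actual ones.

# Sketch of gr.Examples + run_on_click: clicking a row first fills `inputs`
# from the row, then calls `fn` and routes its return values to `outputs`.
import gradio as gr

def build_layout(prompt, *args):
    # Stand-in for build_example_layout: derive extra state from the prompt.
    return f"boxes for: {prompt}"

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    layout = gr.Textbox(label="Layout")
    gr.Examples(
        examples=[["a ginger kitten and a gray puppy in a yard"]],
        inputs=[prompt],     # prefilled directly from the clicked row
        fn=build_layout,     # invoked because run_on_click=True
        outputs=[layout],    # receives build_layout's return value
        run_on_click=True,
    )

if __name__ == "__main__":
    demo.launch()

Note that build_example_layout takes (prompt, *args) because gr.Examples passes every value listed in inputs to fn, while only the prompt is needed to look up EXAMPLE_BOXES.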
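For reading the example rows: subject token indices are packed as semicolon-separated groups, one group per subject, which generate then splits with convert_token_indices(..., nested=True) and checks against the number of boxes. The helper itself is not part of this diff, so the following parser is a hypothetical reconstruction based only on the row format and that length check.

# Hypothetical re-implementation for illustration; the Space's actual
# convert_token_indices is not shown in this commit.
def convert_token_indices(text: str, nested: bool = False):
    if nested:
        # "7,8,17;11,12,17;15,16,17" -> [[7, 8, 17], [11, 12, 17], [15, 16, 17]]
        return [[int(i) for i in group.split(",")] for group in text.split(";")]
    # "5,6,9,10" -> [5, 6, 9, 10]
    return [int(i) for i in text.split(",")]

Under that reading, PROMPT1's "7,8,17;11,12,17;15,16,17" describes three subjects, matching the three boxes it maps to in EXAMPLE_BOXES.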
bounded_attention.py
CHANGED
@@ -38,6 +38,7 @@ class BoundedAttention(injection_utils.AttentionBase):
         start_step_size=30,
         end_step_size=10,
         loss_stopping_value=0.2,
+        min_clustering_step=15,
         cross_mask_threshold=0.2,
         self_mask_threshold=0.2,
         delta_refine_mask_steps=5,
@@ -73,6 +74,7 @@ class BoundedAttention(injection_utils.AttentionBase):
         self.start_step_size = start_step_size
         self.step_size_coef = (end_step_size - start_step_size) / max_guidance_iter
         self.loss_stopping_value = loss_stopping_value
+        self.min_clustering_step = min_clustering_step
         self.cross_mask_threshold = cross_mask_threshold
         self.self_mask_threshold = self_mask_threshold
 
@@ -367,7 +369,7 @@ class BoundedAttention(injection_utils.AttentionBase):
 
     def _obtain_masks(self, resolution, return_boxes=False, return_existing=False, batch_size=None, device=None):
         return_boxes = return_boxes or (return_existing and self.self_masks is None)
-        if return_boxes or self.cur_step < self.
+        if return_boxes or self.cur_step < self.min_clustering_step:
             masks = self._convert_boxes_to_masks(resolution, device=device).unsqueeze(0)
             if batch_size is not None:
                 masks = masks.expand(batch_size, *masks.shape[1:])
@@ -406,21 +408,6 @@ class BoundedAttention(injection_utils.AttentionBase):
         self_masks = F.interpolate(self_masks, resolution, mode='nearest-exact')
         return self_masks.flatten(start_dim=2).bool()
 
-    def _cluster_self_maps(self):  # b s n
-        self_maps = self._compute_maps(self.mean_self_map)  # b n m
-        if self.pca_rank is not None:
-            dtype = self_maps.dtype
-            _, _, eigen_vectors = torch.pca_lowrank(self_maps.float(), self.pca_rank)
-            self_maps = torch.matmul(self_maps, eigen_vectors.to(dtype=dtype))
-
-        clustering_results = self.clustering(self_maps, centers=self.centers)
-        self.clustering.num_init = 1  # clustering is deterministic after the first time
-        self.centers = clustering_results.centers
-        clusters = clustering_results.labels
-        num_clusters = self.clustering.n_clusters
-        self._save_maps(clusters / num_clusters, f'clusters')
-        return num_clusters, clusters
-
     def _build_self_masks(self):
         c, clusters = self._cluster_self_maps()  # b n
         cluster_masks = torch.stack([(clusters == cluster_index) for cluster_index in range(c)], dim=2)  # b n c
@@ -444,6 +431,21 @@ class BoundedAttention(injection_utils.AttentionBase):
         self._save_maps(self_masks, 'self_masks')
         return self_masks
 
+    def _cluster_self_maps(self):  # b s n
+        self_maps = self._compute_maps(self.mean_self_map)  # b n m
+        if self.pca_rank is not None:
+            dtype = self_maps.dtype
+            _, _, eigen_vectors = torch.pca_lowrank(self_maps.float(), self.pca_rank)
+            self_maps = torch.matmul(self_maps, eigen_vectors.to(dtype=dtype))
+
+        clustering_results = self.clustering(self_maps, centers=self.centers)
+        self.clustering.num_init = 1  # clustering is deterministic after the first time
+        self.centers = clustering_results.centers
+        clusters = clustering_results.labels
+        num_clusters = self.clustering.n_clusters
+        self._save_maps(clusters / num_clusters, f'clusters')
+        return num_clusters, clusters
+
     def _obtain_cross_masks(self, resolution, scale=10):
         maps = self._compute_maps(self.mean_cross_map, resolution=resolution)  # b n k
         maps = F.sigmoid(scale * (maps - self.cross_mask_threshold))
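
The "Final step size" tooltip in app.py refers to a linear step size scheduler, and the context lines above show how it is parameterized: step_size_coef = (end_step_size - start_step_size) / max_guidance_iter. A small sketch of what that schedule evaluates to (the function name is mine):

# Linear guidance step-size schedule implied by start_step_size and
# step_size_coef in BoundedAttention; illustrative helper, not repo code.
def guidance_step_size(start: float, end: float,
                       max_guidance_iter: int, cur_iter: int) -> float:
    coef = (end - start) / max_guidance_iter  # matches self.step_size_coef
    return start + coef * cur_iter

# With the demo's slider defaults (initial 30, final 15, 8 guidance steps):
# guidance_step_size(30, 15, 8, 0) -> 30.0
# guidance_step_size(30, 15, 8, 8) -> 15.0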
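
The relocated _cluster_self_maps depends on a self.clustering object that never appears in this diff. As a rough illustration of the technique it implements (reduce the averaged self-attention rows with torch.pca_lowrank, then cluster them), here is a standalone sketch that substitutes a plain k-means for the repo's clustering helper:

# Illustration only: pairs the pca_lowrank reduction visible in the diff
# with Lloyd's k-means standing in for the repo's `self.clustering`.
import torch

def cluster_self_maps(self_maps: torch.Tensor, num_clusters: int,
                      pca_rank: int | None = None, iters: int = 10) -> torch.Tensor:
    """self_maps: (n, m) rows of averaged self-attention; returns (n,) labels."""
    self_maps = self_maps.float()
    if pca_rank is not None:
        # Same reduction as _cluster_self_maps: project onto top PCA directions.
        _, _, eigen_vectors = torch.pca_lowrank(self_maps, pca_rank)
        self_maps = self_maps @ eigen_vectors
    # Plain k-means: assign each row to its nearest center, recompute centers.
    centers = self_maps[torch.randperm(self_maps.shape[0])[:num_clusters]].clone()
    for _ in range(iters):
        labels = torch.cdist(self_maps, centers).argmin(dim=1)  # (n,)
        for k in range(num_clusters):
            members = self_maps[labels == k]
            if len(members) > 0:
                centers[k] = members.mean(dim=0)
    return labels

In the actual class the cluster centers are cached across denoising steps (self.centers) and num_init drops to 1 after the first call, so later steps reuse and refine the first clustering rather than restarting it.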