omer11a committed on
Commit
0f77b33
·
1 Parent(s): a55826f

Fixed some stuff

Browse files
Files changed (1) hide show
  1. app.py +20 -19
app.py CHANGED
@@ -1,9 +1,3 @@
1
- from diffusers import DDIMScheduler
2
- from pipeline_stable_diffusion_xl_opt import StableDiffusionXLPipeline
3
- from injection_utils import regiter_attention_editor_diffusers
4
- from bounded_attention import BoundedAttention
5
- from pytorch_lightning import seed_everything
6
-
7
  import spaces
8
  import gradio as gr
9
  import torch
@@ -11,6 +5,12 @@ import nltk
11
  import numpy as np
12
  from PIL import Image, ImageDraw
13
 
 
 
 
 
 
 
14
  from functools import partial
15
 
16
  RESOLUTION = 256
@@ -113,14 +113,15 @@ def convert_token_indices(token_indices, nested=False):
113
  def draw(sketchpad):
114
  boxes = []
115
  for i, layer in enumerate(sketchpad['layers']):
116
- mask = (layer != 0)
117
- if mask.sum() < 0:
118
- raise gr.Error(f'Box in layer {i} is too small')
119
-
120
- x1x2 = np.where(mask.max(0) != 0)[0] / RESOLUTION
121
- y1y2 = np.where(mask.max(1) != 0)[0] / RESOLUTION
122
- y1, y2 = y1y2.min(), y1y2.max()
123
- x1, x2 = x1x2.min(), x1x2.max()
 
124
 
125
  if (x2 - x1 < MIN_SIZE) or (y2 - y1 < MIN_SIZE):
126
  raise gr.Error(f'Box in layer {i} is too small')
@@ -220,20 +221,20 @@ def main():
220
  )
221
 
222
  subject_token_indices = gr.Textbox(
223
- label="The token indices of each subject (separate indices for the same subject with commas, and between different subjects with semicolons)",
224
  )
225
 
226
  filter_token_indices = gr.Textbox(
227
- label="The token indices to filter, i.e. conjunctions, number, postional relations, etc. (if left empty, this will be automatically inferred)",
228
  )
229
 
230
  num_tokens = gr.Textbox(
231
- label="The number of tokens in the prompt (can be left empty, but we recommend filling this, so we can verify your input, as sometimes rare words are split into more than one token)",
232
  )
233
 
234
  with gr.Row():
235
  sketchpad = gr.Sketchpad(label="Sketch Pad", width=RESOLUTION, height=RESOLUTION)
236
- layout_image = gr.Image(type="pil", label="Bounding Boxes", interactive=False, width=RESOLUTION, height=RESOLUTION)
237
 
238
  with gr.Row():
239
  clear_button = gr.Button(value='Clear')
@@ -278,7 +279,7 @@ def main():
278
  </div>
279
  """
280
  gr.HTML(description)
281
- batch_size = gr.Slider(minimum=1, maximum=5, step=1, value=1, label="Number of samples")
282
  init_step_size = gr.Slider(minimum=0, maximum=50, step=0.5, value=25, label="Initial step size")
283
  final_step_size = gr.Slider(minimum=0, maximum=20, step=0.5, value=10, label="Final step size")
284
  num_clusters_per_subject = gr.Slider(minimum=0, maximum=5, step=0.5, value=3, label="Number of clusters per subject")
 
 
 
 
 
 
 
1
  import spaces
2
  import gradio as gr
3
  import torch
 
5
  import numpy as np
6
  from PIL import Image, ImageDraw
7
 
8
+ from diffusers import DDIMScheduler
9
+ from pipeline_stable_diffusion_xl_opt import StableDiffusionXLPipeline
10
+ from injection_utils import regiter_attention_editor_diffusers
11
+ from bounded_attention import BoundedAttention
12
+ from pytorch_lightning import seed_everything
13
+
14
  from functools import partial
15
 
16
  RESOLUTION = 256
 
113
  def draw(sketchpad):
114
  boxes = []
115
  for i, layer in enumerate(sketchpad['layers']):
116
+ non_zeros = layer.nonzero()
117
+ x1 = x2 = y1 = y2 = 0
118
+ if len(non_zeros[0]) > 0:
119
+ x1x2 = non_zeros[1] / layer.shape[1]
120
+ y1y2 = non_zeros[0] / layer.shape[0]
121
+ x1 = x1x2.min()
122
+ x2 = x1x2.max()
123
+ y1 = y1y2.min()
124
+ y2 = y1y2.max()
125
 
126
  if (x2 - x1 < MIN_SIZE) or (y2 - y1 < MIN_SIZE):
127
  raise gr.Error(f'Box in layer {i} is too small')
 
221
  )
222
 
223
  subject_token_indices = gr.Textbox(
224
+ label="The token indices of each subject (separate indices for the same subject with commas, and for different subjects with semicolons)",
225
  )
226
 
227
  filter_token_indices = gr.Textbox(
228
+ label="Optional: The token indices to filter, i.e. conjunctions, numbers, postional relations, etc. (if left empty, this will be automatically inferred)",
229
  )
230
 
231
  num_tokens = gr.Textbox(
232
+ label="Optional: The number of tokens in the prompt (We use this to verify your input, as sometimes rare words are split into more than one token)",
233
  )
234
 
235
  with gr.Row():
236
  sketchpad = gr.Sketchpad(label="Sketch Pad", width=RESOLUTION, height=RESOLUTION)
237
+ layout_image = gr.Image(type="pil", label="Bounding Boxes", interactive=False, width=RESOLUTION, height=RESOLUTION, scale=1)
238
 
239
  with gr.Row():
240
  clear_button = gr.Button(value='Clear')
 
279
  </div>
280
  """
281
  gr.HTML(description)
282
+ batch_size = gr.Slider(minimum=1, maximum=5, step=1, value=1, label="Number of samples (currently limited to one sample)")
283
  init_step_size = gr.Slider(minimum=0, maximum=50, step=0.5, value=25, label="Initial step size")
284
  final_step_size = gr.Slider(minimum=0, maximum=20, step=0.5, value=10, label="Final step size")
285
  num_clusters_per_subject = gr.Slider(minimum=0, maximum=5, step=0.5, value=3, label="Number of clusters per subject")