omer11a commited on
Commit
a47cf5c
·
1 Parent(s): de34da3

Fixed some stuff

Browse files
Files changed (2) hide show
  1. README.md +2 -2
  2. app.py +10 -23
README.md CHANGED
@@ -1,10 +1,10 @@
1
  ---
2
  title: Bounded Attention
3
- emoji: 🦀
4
  colorFrom: pink
5
  colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 4.22.0
8
  app_file: app.py
9
  pinned: false
10
  license: cc-by-nc-sa-3.0
 
1
  ---
2
  title: Bounded Attention
3
+ emoji:
4
  colorFrom: pink
5
  colorTo: yellow
6
  sdk: gradio
7
+ sdk_version: 3.43.2
8
  app_file: app.py
9
  pinned: false
10
  license: cc-by-nc-sa-3.0
app.py CHANGED
@@ -13,6 +13,7 @@ from functools import partial
13
 
14
  RESOLUTION = 512
15
  MIN_SIZE = 0.01
 
16
  COLORS = ["red", "blue", "green", "orange", "purple", "turquoise", "olive"]
17
 
18
 
@@ -80,10 +81,6 @@ def generate(
80
  seed,
81
  boxes
82
  ):
83
- if 'boxes' not in boxes:
84
- boxes['boxes'] = []
85
-
86
- boxes = boxes['boxes']
87
  subject_token_indices = convert_token_indices(subject_token_indices, nested=True)
88
  if len(boxes) != len(subject_token_indices):
89
  raise ValueError("""
@@ -115,13 +112,11 @@ def convert_token_indices(token_indices, nested=False):
115
  return [int(index.strip()) for index in token_indices.split(',') if len(index.strip()) > 0]
116
 
117
 
118
- def draw(boxes, mask):
119
- print('Called draw')
120
- print('before boxes', boxes)
121
  if mask.ndim == 3:
122
- mask = 255 - mask[..., 0]
123
 
124
- mask = (mask != 0).astype('uint8') * 255
125
  if mask.sum() > 0:
126
  x1x2 = np.where(mask.max(0) != 0)[0] / RESOLUTION
127
  y1y2 = np.where(mask.max(1) != 0)[0] / RESOLUTION
@@ -130,9 +125,8 @@ def draw(boxes, mask):
130
 
131
  if (x2 - x1 > MIN_SIZE) and (y2 - y1 > MIN_SIZE):
132
  boxes.append((x1, y1, x2, y2))
133
- layout_image = draw_boxes(np.array(boxes) * RESOLUTION)
134
 
135
- print('after boxes', boxes)
136
  return [boxes, None, layout_image]
137
 
138
 
@@ -140,10 +134,10 @@ def draw_boxes(boxes):
140
  if len(boxes) == 0:
141
  return None
142
 
143
- image = Image.new('RGB', (RESOLUTION, RESOLUTION), (255, 255, 255))
 
144
  drawing = ImageDraw.Draw(image)
145
- print(boxes)
146
- for i, box in enumerate(boxes):
147
  drawing.rectangle(box, outline=COLORS[i % len(COLORS)], width=4)
148
 
149
  return image
@@ -294,16 +288,9 @@ def main():
294
 
295
  boxes = gr.State([])
296
 
297
- demo.load(
298
- clear,
299
- inputs=[batch_size],
300
- outputs=[boxes, sketch_pad, layout_image, out_images],
301
- queue=False
302
- )
303
-
304
  sketch_pad.edit(
305
  draw,
306
- inputs=[boxes, sketch_pad],
307
  outputs=[boxes, sketch_pad, layout_image],
308
  queue=False,
309
  )
@@ -347,7 +334,7 @@ def main():
347
  description = """<p> The source code of this demo is based on the <a href="https://huggingface.co/spaces/gligen/demo/tree/main">GLIGEN demo</a>.</p>"""
348
  gr.HTML(description)
349
 
350
- demo.queue(concurrency_count=1, api_open=False)
351
  demo.launch(share=False, show_api=False, show_error=True)
352
 
353
  if __name__ == '__main__':
 
13
 
14
  RESOLUTION = 512
15
  MIN_SIZE = 0.01
16
+ WHITE = 255
17
  COLORS = ["red", "blue", "green", "orange", "purple", "turquoise", "olive"]
18
 
19
 
 
81
  seed,
82
  boxes
83
  ):
 
 
 
 
84
  subject_token_indices = convert_token_indices(subject_token_indices, nested=True)
85
  if len(boxes) != len(subject_token_indices):
86
  raise ValueError("""
 
112
  return [int(index.strip()) for index in token_indices.split(',') if len(index.strip()) > 0]
113
 
114
 
115
+ def draw(boxes, mask, layout_image):
 
 
116
  if mask.ndim == 3:
117
+ mask = WHITE - mask[..., 0]
118
 
119
+ mask = (mask != 0).astype('uint8') * WHITE
120
  if mask.sum() > 0:
121
  x1x2 = np.where(mask.max(0) != 0)[0] / RESOLUTION
122
  y1y2 = np.where(mask.max(1) != 0)[0] / RESOLUTION
 
125
 
126
  if (x2 - x1 > MIN_SIZE) and (y2 - y1 > MIN_SIZE):
127
  boxes.append((x1, y1, x2, y2))
128
+ layout_image = draw_boxes(boxes)
129
 
 
130
  return [boxes, None, layout_image]
131
 
132
 
 
134
  if len(boxes) == 0:
135
  return None
136
 
137
+ boxes = np.array(boxes) * RESOLUTION
138
+ image = Image.new('RGB', (RESOLUTION, RESOLUTION), (WHITE, WHITE, WHITE))
139
  drawing = ImageDraw.Draw(image)
140
+ for i, box in enumerate(boxes.astype(int).tolist()):
 
141
  drawing.rectangle(box, outline=COLORS[i % len(COLORS)], width=4)
142
 
143
  return image
 
288
 
289
  boxes = gr.State([])
290
 
 
 
 
 
 
 
 
291
  sketch_pad.edit(
292
  draw,
293
+ inputs=[boxes, sketch_pad, layout_image],
294
  outputs=[boxes, sketch_pad, layout_image],
295
  queue=False,
296
  )
 
334
  description = """<p> The source code of this demo is based on the <a href="https://huggingface.co/spaces/gligen/demo/tree/main">GLIGEN demo</a>.</p>"""
335
  gr.HTML(description)
336
 
337
+ demo.queue(max_size=50)
338
  demo.launch(share=False, show_api=False, show_error=True)
339
 
340
  if __name__ == '__main__':