Spaces:
Runtime error
Runtime error
Commit
·
2f56479
0
Parent(s):
init
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +35 -0
- .gitignore +4 -0
- README.md +12 -0
- adapter.py +244 -0
- app.py +335 -0
- config_generator.py +51 -0
- dataset_splits/dev_imgs.pkl +0 -0
- dataset_splits/test_imgs.pkl +0 -0
- dataset_splits/train_imgs.pkl +0 -0
- requirements.txt +9 -0
- tangram_pngs/page-A.png +0 -0
- tangram_pngs/page-B.png +0 -0
- tangram_pngs/page-C.png +0 -0
- tangram_pngs/page-D.png +0 -0
- tangram_pngs/page-E.png +0 -0
- tangram_pngs/page-F.png +0 -0
- tangram_pngs/page-G.png +0 -0
- tangram_pngs/page-H.png +0 -0
- tangram_pngs/page-I.png +0 -0
- tangram_pngs/page-J.png +0 -0
- tangram_pngs/page-K.png +0 -0
- tangram_pngs/page-L.png +0 -0
- tangram_pngs/page1-0.png +0 -0
- tangram_pngs/page1-1.png +0 -0
- tangram_pngs/page1-10.png +0 -0
- tangram_pngs/page1-103.png +0 -0
- tangram_pngs/page1-105.png +0 -0
- tangram_pngs/page1-106.png +0 -0
- tangram_pngs/page1-107.png +0 -0
- tangram_pngs/page1-108.png +0 -0
- tangram_pngs/page1-109.png +0 -0
- tangram_pngs/page1-110.png +0 -0
- tangram_pngs/page1-112.png +0 -0
- tangram_pngs/page1-113.png +0 -0
- tangram_pngs/page1-114.png +0 -0
- tangram_pngs/page1-116.png +0 -0
- tangram_pngs/page1-117.png +0 -0
- tangram_pngs/page1-118.png +0 -0
- tangram_pngs/page1-119.png +0 -0
- tangram_pngs/page1-122.png +0 -0
- tangram_pngs/page1-125.png +0 -0
- tangram_pngs/page1-128.png +0 -0
- tangram_pngs/page1-129.png +0 -0
- tangram_pngs/page1-13.png +0 -0
- tangram_pngs/page1-130.png +0 -0
- tangram_pngs/page1-132.png +0 -0
- tangram_pngs/page1-133.png +0 -0
- tangram_pngs/page1-136.png +0 -0
- tangram_pngs/page1-137.png +0 -0
- tangram_pngs/page1-14.png +0 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__pycache__/
|
2 |
+
*.pyc
|
3 |
+
.DS_Store
|
4 |
+
.vscode
|
README.md
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: Respect
|
3 |
+
emoji: 🫡
|
4 |
+
colorFrom: pink
|
5 |
+
colorTo: yellow
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 4.44.0
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
+
---
|
11 |
+
|
12 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
adapter.py
ADDED
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import re
|
3 |
+
from functools import cache
|
4 |
+
from pathlib import Path
|
5 |
+
from typing import List, Set, Tuple, TypeVar
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from PIL import Image
|
9 |
+
|
10 |
+
from utils import device, nested_apply, sorted_list
|
11 |
+
|
12 |
+
# Grammar of a legal listener answer: "select A B", "deselect C", or
# "deselect C select A B" — letters are the single-character hash names.
RE_PATTERN = r'^(deselect\s[A-Z](?:\s[A-Z])*(?:\sselect\s[A-Z](?:\s[A-Z])*)?|select\s[A-Z](?:\s[A-Z])*)$'  # noqa


# Name type, newtype of str. e.g. "page4-249.png"
N = TypeVar('N')

ALPHABET = 'ABCDEFGHIJ'  # we only have 10 images

# Token ids the model is allowed to emit during generation.
LEGAL_TOKEN_IDS = [2, 315, 330, 334, 365, 382, 384, 401, 413,
                   420, 475, 5339, 634, 17960, 32002]  # A - J and <end_of_utterance> and <\s> and 'select' and 'deselect'


# Human-readable decoding of LEGAL_TOKEN_IDS, for debugging only.
# NOTE(review): 17960/'elect' and 634/'des' are sub-word pieces of
# "deselect" in the idefics2 vocabulary — confirm against the tokenizer.
MINI_DECODER = {
    384: 'D',
    # 2: '</s>',
    32002: '<end_of_utterance>',
    420: 'G', 17960: 'elect',
    330: 'A', 365: 'B', 334: 'C', 5339: 'select', 401: 'F', 475: 'J',
    634: 'des', 315: 'I', 413: 'E', 382: 'H'}
|
30 |
+
|
31 |
+
|
32 |
+
class AlphabeticNameHash:
    """Bidirectional mapping between image names and single letters.

    The i-th name in *context* is assigned ALPHABET[i]; at most
    len(ALPHABET) == 10 names are supported.
    """

    def __init__(self, context: List[N]) -> None:
        # BUG FIX: the original decorated __init__ with @cache. That can
        # never produce a cache hit (every call binds a fresh `self`) and
        # it keeps each instance alive in the cache forever — a pure leak.
        # Construction is two small dict comprehensions, so no caching is
        # needed at all.
        self._forward_map = {im: ALPHABET[i] for i, im in enumerate(context)}
        self._backward_map = {ALPHABET[i]: im for i, im in enumerate(context)}

    def hash(self, im: N) -> str:
        """Return the letter assigned to image name *im* (KeyError if absent)."""
        return self._forward_map[im]

    def unhash(self, i: str) -> N:
        """Return the image name assigned to letter *i* (KeyError if absent)."""
        return self._backward_map[i]

    def valid_hash(self, i: str) -> bool:
        """True iff *i* is a letter assigned to some name in this context."""
        return i in self._backward_map
|
47 |
+
|
48 |
+
|
49 |
+
class IdeficsAdapter:
    """Bridges the game state and the Idefics2 processor/tokenizer.

    Responsibilities:
      * ``compose`` — turn dialogue/selection history into model inputs;
      * ``parse`` / ``parse_raw`` — decode the model's text output back
        into image clicks (file names).
    """

    PAD_TOKEN_ID = 0
    LABEL_MASK_ID = 32001  # idefics2: image_token_id
    LEGAL_TOKEN_IDS = LEGAL_TOKEN_IDS
    # Vocabulary-sized boolean mask: True only at the legal output tokens.
    LEGAL_TOKEN_MASK = torch.zeros(32003, requires_grad=False)\
        .index_fill_(0, torch.tensor(LEGAL_TOKEN_IDS), 1).to(device=device(), dtype=torch.bool)
    # Complement of LEGAL_TOKEN_IDS, handed to generate(suppress_tokens=...).
    SUPPRESS_TOKEN_IDS = list(set(range(32003)) - set(LEGAL_TOKEN_IDS))

    def __init__(self, image_folder: str, processor) -> None:
        """Store the processor and the folder containing the tangram PNGs."""
        self.t_max_length = 2048  # truncation bound for processor inputs
        self.image_folder = Path(image_folder)
        self.image_cache = {}  # file name -> PIL.Image, filled lazily
        self.processor = processor
        self.tokenizer = self.processor.tokenizer

    def get_image(self, im_name: N) -> Image.Image:
        """Open (and memoize) the image file named *im_name*."""
        if im_name not in self.image_cache:
            self.image_cache[im_name] = Image.open(
                self.image_folder.joinpath(im_name))
        return self.image_cache[im_name]

    def unhash(self, context: List[N], c: str):
        """Map letter *c* back to its image name within *context*."""
        return AlphabeticNameHash(tuple(context)).unhash(c)

    def valid_hash(self, context: List[N], c: str):
        """True iff *c* is a letter assigned to some image in *context*."""
        return AlphabeticNameHash(tuple(context)).valid_hash(c)

    def parse(self, context: List[N], decoded_out: str,
              currently_selected: List[N]) -> List[str]:
        """Extract the image clicks implied by the model output.

        Applies "strict cleaning": deselections of images that are not
        currently selected and selections of images that already are get
        dropped.  Returns the surviving clicks as image file names.
        """
        h = AlphabeticNameHash(tuple(context))
        logging.debug(f"{context=}")
        # do inference
        logging.debug(f"{decoded_out=}")
        selection, deselection = self.parse_raw(decoded_out)

        hashed_currently_selected = {h.hash(n) for n in currently_selected}
        desel_to_remove = deselection - hashed_currently_selected
        if len(desel_to_remove) > 0:
            logging.debug(f"warn! {desel_to_remove=}")
            deselection = deselection - desel_to_remove

        sel_to_remove = selection & hashed_currently_selected
        if len(sel_to_remove) > 0:
            logging.debug(f"warn! {sel_to_remove=}")
            selection = selection - sel_to_remove

        logging.debug("post strict cleaning")
        logging.debug(f"{selection=}")
        logging.debug(f"{deselection=}")

        model_clicks = selection | deselection
        logging.debug(f"{model_clicks=}")
        model_clicks_png = [h.unhash(n)
                            for n in model_clicks if h.valid_hash(n)]
        logging.debug(f"{model_clicks_png=}")
        return model_clicks_png

    @staticmethod
    def parse_raw(text: str) -> Tuple[Set[N], Set[N]]:
        """Split the model's answer into (selections, deselections) letters.

        Only answers matching RE_PATTERN ("select A B", "deselect C",
        "deselect C select A" ...) are accepted; anything else yields two
        empty sets.
        """
        last_answer = text.strip()
        if ":" in text:
            # Keep only what follows the final colon (the "last answer").
            last_answer_pattern = r":.*$"
            xs = re.findall(last_answer_pattern, text)
            # BUG FIX: without re.MULTILINE, a ":" on a non-final line
            # produces no match and xs[0] raised IndexError; fall back to
            # the stripped text instead of crashing.
            if xs:
                last_answer = xs[0].removeprefix(":").strip()
        xs = re.search(RE_PATTERN, last_answer)
        if xs is None:
            print(f"{last_answer=}")
            print("did not pass regex")
            return set(), set()

        select_pattern = r"(?<!de)select( [A-J])+$"
        xs = re.search(select_pattern, last_answer)
        if xs is not None:
            xs = xs.group()
        selections = set(xs.split(" ")[1:]) if xs else set()

        deselect_pattern = r"^deselect( [A-J])+"
        xs = re.search(deselect_pattern, last_answer)
        if xs is not None:
            xs = xs.group()
        deselections = set(xs.split(" ")[1:]) if xs else set()

        return selections, deselections

    def compose(self, context, chats, previous_selected, hash_images, padding):
        """Build processor inputs for the next listener turn.

        *previous_selected* holds the full selection state after each past
        turn; it is unfolded into per-turn select/deselect actions, padded
        with an empty action slot for the upcoming turn.
        """
        select_accum, deselect_accum, clickss = self.unfold_select_deselect(
            previous_selected)

        select_accum = select_accum + [[]]
        deselect_accum = deselect_accum + [[]]
        previous_selected = [[]] + previous_selected  # old states pre click
        assert len(chats) == len(select_accum) == len(
            deselect_accum) == len(previous_selected)

        messages, images = self.build_processor_input(
            context, chats, select_accum, deselect_accum, previous_selected,
            hash_images, omit_last_answer=True, sort_names=True,
            omit_context=False, chat_feedback=None)
        prompt = self.processor.apply_chat_template(
            messages, add_generation_prompt=True)
        prompt = prompt.strip()
        logging.debug(prompt)
        # Keep consistent with train_script
        inputs = self.processor(
            text=prompt, images=images,
            padding=padding, truncation=True, max_length=self.t_max_length,
            return_tensors="pt")
        return inputs

    def build_processor_input(self, image_pngs: List[N], chats: List[str],
                              select_accum: List[List[N]],
                              deselect_accum: List[List[N]],
                              pre_click_selected_accum: List[List[N]],
                              hash_image: bool, omit_last_answer: bool,
                              sort_names: bool, omit_context: bool,
                              chat_feedback: str, ):
        """Assemble the chat-template message list and image list.

        Returns (messages, images) suitable for
        ``processor.apply_chat_template`` / ``processor(...)``.
        """
        def _text_content(text): return {"type": "text", "text": text}

        def _image_content(): return {"type": "image"}

        def _user_prompt(content): return {"role": "user", "content": content}

        def _assistant_prompt(content): return {
            "role": "assistant", "content": content}

        def _system_prompt(content): return {
            "role": "system", "content": content}

        def _current_state(selected: List[N]):
            if len(selected) == 0:
                return 'none is selected'
            return f'{" ".join(selected)} currently selected'

        def _listener_action(select: List[N], deselect: List[N]):
            if len(select) == 0 and len(deselect) == 0:
                return 'nothing'
            if len(select) == 0:
                return f'deselect {" ".join(deselect)}'
            if len(deselect) == 0:
                return f'select {" ".join(select)}'
            return f'deselect {" ".join(deselect)} select {" ".join(select)}'

        # BUG FIX: the no-hash branch used the builtin `id`, which maps
        # names to integer object ids; use a real identity function.
        func = AlphabeticNameHash(tuple(image_pngs)).hash if hash_image \
            else (lambda x: x)
        context, select_accum, deselect_accum, pre_click_selected_accum = nested_apply(
            func, (image_pngs, select_accum, deselect_accum, pre_click_selected_accum))

        prompt = []
        images = []
        if not omit_context:
            images = [self.get_image(im) for im in image_pngs]
            images_and_names_content = []
            for im_name in context:
                images_and_names_content.append(_image_content())
                images_and_names_content.append(_text_content(im_name))
            prompt.append(_system_prompt(images_and_names_content))
        if not len(chats) == len(select_accum) == len(deselect_accum) == len(pre_click_selected_accum):
            logging.error(f"{chats=}")
            logging.error(f"{select_accum=}")
            logging.error(f"{deselect_accum=}")
            logging.error(f"{pre_click_selected_accum=}")
            assert False
        for i, (chat, select, deselect, pre_click_selected) in enumerate(
                zip(chats, select_accum, deselect_accum, pre_click_selected_accum)):
            if sort_names:
                select = sorted(select)
                deselect = sorted(deselect)
                pre_click_selected = sorted(pre_click_selected)

            prompt.append(_system_prompt(
                [_text_content(_current_state(pre_click_selected))]))
            prompt.append(_user_prompt([_text_content(chat)]))
            prompt.append(_assistant_prompt(
                [_text_content(_listener_action(select, deselect))]))
        if omit_last_answer:
            # idefics2 has processor.apply_chat_template(messages, add_generation_prompt=True) instead
            prompt.pop(-1)
        if chat_feedback is not None:
            prompt.append(_user_prompt([_text_content(chat_feedback)]))
        return prompt, images

    def unfold_select_deselect(
            self, previous_selected: List[List[N]]
    ) -> Tuple[List[List[N]], List[List[N]], List[List[N]]]:
        """Diff consecutive selection states into per-turn actions.

        ``previous_selected[i]`` is the selection state AFTER turn i.
        Returns (selected, deselected, clicks) — per-turn sorted lists of
        newly selected names, newly deselected names, and their union.
        (The original return annotation claimed flat lists; corrected.)
        """
        # currently selected AFTER i-th turn
        num_turns = len(previous_selected)
        selected: List[List[str]] = []  # turn-wise selection
        deselected: List[List[str]] = []  # turn-wise deselection
        clicks: List[List[str]] = []
        # combining turn-wise newly selected and newly deselected
        prev_selected = set()
        for turn in range(num_turns):
            curr_selected = set(previous_selected[turn])
            newly_selected = curr_selected - prev_selected
            newly_deselected = prev_selected - curr_selected
            selected.append(sorted_list(newly_selected))
            deselected.append(sorted_list(newly_deselected))
            clicks.append(sorted_list(newly_selected | newly_deselected))
            prev_selected = curr_selected.copy()
        return selected, deselected, clicks
|
app.py
ADDED
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import dataclasses
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
from typing import Any, Dict, List
|
5 |
+
|
6 |
+
import gradio as gr # type: ignore
|
7 |
+
import PIL.Image as Image
|
8 |
+
import PIL.ImageOps as ImageOps
|
9 |
+
import spaces # type: ignore
|
10 |
+
import torch
|
11 |
+
from peft import PeftModel # type: ignore
|
12 |
+
from transformers import AutoProcessor # type: ignore
|
13 |
+
from transformers import Idefics2ForConditionalGeneration, Idefics2Processor
|
14 |
+
|
15 |
+
from adapter import IdeficsAdapter
|
16 |
+
from config_generator import GameConfig, generate_game_config
|
17 |
+
from utils import device, nested_to_device, sorted_list
|
18 |
+
import copy
|
19 |
+
|
20 |
+
### Constants
|
21 |
+
css="""
|
22 |
+
.radio-group .wrap {
|
23 |
+
display: grid;
|
24 |
+
grid-template-columns: repeat(5, 1fr);
|
25 |
+
grid-template-rows: repeat(5, 1fr);
|
26 |
+
width: 100%;
|
27 |
+
height: 100%
|
28 |
+
}
|
29 |
+
"""
|
30 |
+
IMG_DIR = "tangram_pngs"
|
31 |
+
|
32 |
+
|
33 |
+
### Bot server
|
34 |
+
|
35 |
+
# Generation settings shared by every model call.  Decoding is restricted
# to the small legal vocabulary (letters A-J, the select/deselect pieces,
# end markers) by suppressing every other token id.
GEN_KWS: Dict[str, Any] = {
    "max_new_tokens": 10,
    "do_sample": True,
    "temperature": 1.0,
    "output_logits": True,
    "return_dict_in_generate": True,
    "remove_invalid_values": True,  # just to be safe
    "renormalize_logits": True,
    "suppress_tokens": IdeficsAdapter.SUPPRESS_TOKEN_IDS
}
|
45 |
+
|
46 |
+
@spaces.GPU(duration=20)
def get_model_response(  # predict
    model: PeftModel, adapter_name: str, adapter: IdeficsAdapter,
    image_paths: List[str], chat : str, chats: List[str],
    previous_selected: List[List[str]]
) -> List[str]:
    """Run one listener turn: encode the dialogue, generate, parse clicks.

    Activates *adapter_name* on the model if needed, composes the full
    dialogue (history plus the new *chat*) into processor inputs, decodes
    the generation, and returns the resulting image clicks.  If the model
    produced no usable click, the first context image is returned as a
    fallback.
    """
    if model.active_adapter != adapter_name:
        model.set_adapter(adapter_name)

    model.to(device())

    history = chats + [chat]
    selected_now = previous_selected[-1] if previous_selected else []
    batch: Dict[str, Any] = adapter.compose(  # type: ignore
        image_paths, history, previous_selected, True, False)
    batch = nested_to_device(batch)  # type: ignore

    autocast_ctx = torch.autocast(device_type=device().type,
                                  dtype=torch.bfloat16)
    with torch.inference_mode(), autocast_ctx:
        generated = model.generate(**batch, **GEN_KWS)  # type: ignore

    decoded_out: str = adapter.tokenizer.decode(  # type: ignore
        generated.sequences[0], skip_special_tokens=True)
    model_clicks = adapter.parse(
        image_paths, decoded_out, selected_now)  # type: ignore

    if not model_clicks:
        # Fallback: always act on at least one image.
        logging.warning("empty clicks by model")
        model_clicks = [image_paths[0]]
        logging.debug(f"{image_paths=}")
        logging.debug(f"selecting {model_clicks}")
        prob = -1
    else:
        prob = -3
    logging.debug(f"{prob=}")
    logging.info(f"User input: {chat}")
    logging.info(f"Model selected: {model_clicks}")
    logging.debug(f"Model output: {decoded_out}")
    return model_clicks
|
85 |
+
|
86 |
+
|
87 |
+
def get_model() -> PeftModel:
    """Load the base idefics2-8b model and attach both LoRA adapters.

    Returns a PeftModel carrying adapter "r6_bp" (final system, active by
    default) and "r0" (initial system), both loaded from the
    lil-lab/respect repo at their matching revisions.
    """
    model_id = 'lil-lab/respect'
    checkpoint = "HuggingFaceM4/idefics2-8b"
    model = Idefics2ForConditionalGeneration.from_pretrained(  # type: ignore
        checkpoint, torch_dtype=torch.bfloat16,
    )
    peft_model = PeftModel.from_pretrained(  # type: ignore
        model, model_id, adapter_name="r6_bp", is_trainable=False, revision="r6_bp")

    # Add other adapter - hack to avoid conflict
    # NOTE(review): target_modules is rebuilt from the lora layer names
    # already injected by the first adapter, so the second adapter attaches
    # to exactly the same modules — confirm this matches the r0 checkpoint.
    lora_config = copy.deepcopy(peft_model.active_peft_config)
    targets = list(set(n[:n.find('lora')-1] for n, _ in model.named_parameters()
                       if 'lora' in n))
    lora_config.target_modules = targets
    peft_model.add_adapter("r0", lora_config)
    peft_model.load_adapter(model_id, "r0", is_trainable=False, revision="r0",
                            peft_config=lora_config)
    return peft_model
|
105 |
+
|
106 |
+
def get_processor() -> Idefics2Processor:
    """Build the idefics2 processor: 224px images, no image splitting."""
    checkpoint = "HuggingFaceM4/idefics2-8b"
    size_opts = {"longest_edge": 224, "shortest_edge": 224}
    proc = AutoProcessor.from_pretrained(  # type: ignore
        checkpoint, do_image_splitting=False, size=size_opts)
    return proc  # type: ignore
|
112 |
+
|
113 |
+
def get_adapter() -> IdeficsAdapter:
    """Construct the IdeficsAdapter over the local tangram image folder."""
    return IdeficsAdapter(IMG_DIR, get_processor())
|
116 |
+
|
117 |
+
|
118 |
+
### Game logic
|
119 |
+
|
120 |
+
@dataclasses.dataclass(frozen=False)
class GameState:
    """Mutable record of one game: configuration, dialogue, selections."""
    config: GameConfig
    adapter_name: str
    chats: List[str]
    currently_selected: List[str]
    selected_accum: List[List[str]]
    clicks_accum: List[List[str]]
    turn: int = 0

    def has_ended(self):
        """The game stops on success or after 10 turns."""
        return self.turn >= 10 or self.has_successfully_ended()

    def has_successfully_ended(self):
        """True when the selection matches the target set exactly."""
        return set(self.config.targets) == set(self.currently_selected)

    ### UI helpers

    def serialize_conversation(self):
        """Render the chat history, one 'Turn k: message' line per turn."""
        lines = []
        for turn_no, message in enumerate(self.chats, start=1):
            lines.append(f"Turn {turn_no}: {message}")
        return "\n".join(lines)

    def markup_images(self):
        """Annotate the speaker-side context images for the gallery."""
        cfg = self.config
        latest_changes = self.selected_accum[-1] if self.selected_accum else []
        return self._display_context(cfg.speaker_context, cfg.targets,
                                     latest_changes, self.currently_selected)

    @staticmethod
    def _display_context(context: List[str], targets: List[str],
                         changes: List[str], selected: List[str]) -> List[Image.Image]:
        """Draw status borders (green/black/red/white) and yellow action dots."""
        dot = Image.open("yellow_circle.png").resize((20, 20)).convert("RGBA")
        rendered: List[Image.Image] = []
        for name in context:
            img = Image.open(os.path.join(IMG_DIR, name)).resize((60, 60)).convert("RGB")
            img = ImageOps.expand(img, border=2, fill="white")
            is_target = name in targets
            is_sel = name in selected
            if is_target and is_sel:
                color = "green"   # listener selected a target image
            elif is_target:
                color = "black"   # unselected target
            elif is_sel:
                color = "red"     # listener selected a wrong image
            else:
                color = "white"
            img = ImageOps.expand(img, border=10, fill=color)
            img = ImageOps.expand(img, border=2, fill="white")
            if name in changes:
                img.paste(dot, (68, 0), mask=dot)
            rendered.append(img)
        return rendered
|
174 |
+
|
175 |
+
|
176 |
+
class GameFlow:
    """State transitions for a game: start a round, advance one turn."""

    @classmethod
    def initialize(cls, model_iteration: str) -> GameState:
        """Start a fresh game, picking the adapter for the chosen system."""
        chosen = "r0" if model_iteration == "Initial System" else "r6_bp"
        return GameState(
            config=generate_game_config(),
            adapter_name=chosen,
            chats=[],
            currently_selected=[],
            selected_accum=[],
            clicks_accum=[],
            turn=0,
        )

    @classmethod
    def progress(cls, state: GameState, chat: str,
                 model: PeftModel,
                 adapter: IdeficsAdapter) -> GameState:
        """Apply one speaker message and return the successor state."""
        clicks = get_model_response(
            model, state.adapter_name, adapter,
            state.config.listener_context, chat,
            state.chats, state.selected_accum,
        )

        # Each click toggles an image: the new selection is the symmetric
        # difference of the old selection and the clicked set.
        new_selection = sorted_list(
            set(state.currently_selected) ^ set(clicks))

        return GameState(
            # constants
            config=state.config,
            adapter_name=state.adapter_name,
            # updates
            chats=state.chats + [chat],
            currently_selected=new_selection,
            selected_accum=state.selected_accum + [new_selection],
            clicks_accum=state.clicks_accum + [clicks],
            turn=state.turn + 1,
        )
|
224 |
+
|
225 |
+
|
226 |
+
|
227 |
+
### UI
|
228 |
+
|
229 |
+
def create_app_inner():
    """Build the Gradio UI for the tangram game and wire its callbacks.

    Must be called inside an open ``gr.Blocks`` context (see create_app).
    """
    ### layout
    gr.Markdown("# Tangram Multi-Reference Game")
    gr.Markdown(
        '### You will be playing a multi-reference games against a model. \
        To start a game, first select whether you wish to play against our \
        initial trained model ("Initial System") or \
        our model at the end of continual learning ("Final System") \
        and press the "Start Game" button. \
        You will take on a "speaker" role at each round. \
        Your goal is to describe this image (via a message in the textbox) \
        so that the model can guess what it is.'
    )

    gr.Markdown("Targets have black borders. Correctly selected targets have green borders. Incorrectly selected targets have red borders. Actions are marked with yellow dot.")

    gr.Markdown("The listener cannot see boxes or colors and the order is different.")

    gr.Markdown(
        '### Press "Send" to submit your action to proceed to the next turn. \
        You have 10 turns in total.'
    )

    with gr.Row():
        model_iteration = gr.Radio(["Initial System", "Final System"],
                                   label="Model Iteration",
                                   value="Final System")
        start_btn = gr.Button("Start Game")

    with gr.Row():
        current_turn = gr.Textbox(label="TURN")
        success = gr.Textbox(label="Success")

    with gr.Row():
        image_output = gr.Gallery(
            label="CONTEXT", show_label=False, elem_id="gallery",
            columns=5, rows=2, object_fit="contain", height="250px",
            allow_preview=False, container=True, interactive=False
        )

    with gr.Row():
        conversation_output = gr.Textbox(label="Interaction History")
        user_input = gr.Textbox(label="Your Message as Speaker", interactive=True)

    send_btn = gr.Button("Send", interactive=True)

    ### globals
    # Model and adapter are loaded once at UI build time and captured by
    # the closures below; gr.State holds the per-session game state.
    model = get_model()
    adapter = get_adapter()
    game_state = gr.State(value=None)

    ### callbacks
    def output_from_state(state: GameState):
        # Single place that maps a GameState onto every output component.
        # NOTE(review): shows turn+1, so after the 10th turn this renders
        # "11/10" — confirm whether that is intended.
        has_ended = state.has_ended()
        success = "success" if state.has_successfully_ended() else "failure"
        return (
            state.markup_images(),  # image_output
            state.serialize_conversation(),  # conversation_output
            f"{state.turn+1}/10",  # current_turn
            success if has_ended else "n/a",  # success
            gr.update(interactive=not has_ended, value=""),  # user_input
            gr.update(interactive=not has_ended),  # send_btn
            gr.update(interactive=has_ended),  # model_iteration
            state,  # game_history
        )

    def on_start_interaction(model_iteration: str):
        assert model_iteration in ["Initial System", "Final System"]
        state = GameFlow.initialize(model_iteration)
        return output_from_state(state)

    def on_send_message(message: str, state: GameState):
        nonlocal model
        nonlocal adapter
        # Empty messages are ignored rather than consuming a turn.
        if message.strip() == "":
            logging.info("Empty message")
            return output_from_state(state)
        state = GameFlow.progress(state, message, model, adapter)
        return output_from_state(state)

    start_btn.click(
        on_start_interaction,
        inputs=[model_iteration],
        outputs=[image_output, conversation_output, current_turn, success,
                 user_input, send_btn, model_iteration, game_state],
        queue=False
    )

    send_btn.click(
        on_send_message,
        inputs=[user_input, game_state],
        outputs=[image_output, conversation_output, current_turn, success,
                 user_input, send_btn, model_iteration, game_state],
        queue=True
    )
|
324 |
+
|
325 |
+
|
326 |
+
def create_app():
    """Wrap the UI in a Blocks context with the grid CSS applied."""
    with gr.Blocks(css=css) as blocks_app:
        create_app_inner()
    return blocks_app
|
330 |
+
|
331 |
+
|
332 |
+
if __name__ == "__main__":
    # Build the UI, enable request queueing, and start the Gradio server.
    app = create_app()
    app.queue()
    app.launch()
|
config_generator.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import dataclasses
|
2 |
+
import functools
|
3 |
+
import logging
|
4 |
+
import os
|
5 |
+
import pickle
|
6 |
+
import pprint
|
7 |
+
import random
|
8 |
+
from typing import List
|
9 |
+
|
10 |
+
EMPTY_DATA_PATH = "tangram_pngs/"
|
11 |
+
SPLIT_PATH = "dataset_splits/"
|
12 |
+
|
13 |
+
|
14 |
+
@dataclasses.dataclass(frozen=True)
class GameConfig:
    """Immutable description of one game round.

    speaker_context and listener_context hold the same image file names
    in (possibly) different orders; targets is the subset the speaker
    must get the listener to select.
    """
    speaker_context: List[str]
    listener_context: List[str]
    targets: List[str]
|
19 |
+
|
20 |
+
|
21 |
+
def generate_game_config() -> GameConfig:
    """Sample a random round: 10 context images, 3-5 targets, shuffled view.

    The sequence of `random` calls is kept identical to preserve
    reproducibility under a fixed seed.
    """
    corpus = _get_data()
    ctx = random.sample(corpus, 10)
    how_many = random.randint(3, 5)
    tgts = random.sample(ctx, how_many)
    order = list(range(10))
    random.shuffle(order)

    cfg = GameConfig(
        speaker_context=ctx,
        listener_context=[ctx[i] for i in order],
        targets=tgts,
    )
    logging.info(f"context_dict: {pprint.pformat(dataclasses.asdict(cfg))}")
    return cfg
|
36 |
+
|
37 |
+
@functools.cache
def _get_data(restricted_dataset: bool=False):
    """Return the list of candidate tangram PNG file names (memoized).

    With restricted_dataset=False, everything in tangram_pngs/ is listed.
    Otherwise the test+train split pickles are loaded; those entries are
    presumably bare stems, hence the ".png" suffix — TODO confirm.
    macOS junk and known duplicate tangrams are filtered out either way.
    """
    if not restricted_dataset:
        # 1013 images
        paths = os.listdir(EMPTY_DATA_PATH)
    else:
        # 912 images
        with open(os.path.join(SPLIT_PATH, "test_imgs.pkl"), 'rb') as f:
            paths = pickle.load(f)
        with open(os.path.join(SPLIT_PATH, "train_imgs.pkl"), 'rb') as f:
            paths += pickle.load(f)
        paths = [path + ".png" for path in paths]
    dup_images = ["page6-51.png", "page6-66.png", "page4-170.png"]
    paths = [path for path in paths if path != ".DS_Store" and path not in dup_images]
    return paths
|
dataset_splits/dev_imgs.pkl
ADDED
Binary file (1.18 kB). View file
|
|
dataset_splits/test_imgs.pkl
ADDED
Binary file (5.26 kB). View file
|
|
dataset_splits/train_imgs.pkl
ADDED
Binary file (5.26 kB). View file
|
|
requirements.txt
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==2.2.0
|
2 |
+
datasets==2.18.0
|
3 |
+
transformers==4.40.0
|
4 |
+
accelerate==0.29.2
|
5 |
+
loralib==0.1.2
|
6 |
+
peft==0.10.0
|
7 |
+
nltk==3.8.1
|
8 |
+
gradio==4.44.1
|
9 |
+
spaces==0.30.4
|
tangram_pngs/page-A.png
ADDED
![]() |
tangram_pngs/page-B.png
ADDED
![]() |
tangram_pngs/page-C.png
ADDED
![]() |
tangram_pngs/page-D.png
ADDED
![]() |
tangram_pngs/page-E.png
ADDED
![]() |
tangram_pngs/page-F.png
ADDED
![]() |
tangram_pngs/page-G.png
ADDED
![]() |
tangram_pngs/page-H.png
ADDED
![]() |
tangram_pngs/page-I.png
ADDED
![]() |
tangram_pngs/page-J.png
ADDED
![]() |
tangram_pngs/page-K.png
ADDED
![]() |
tangram_pngs/page-L.png
ADDED
![]() |
tangram_pngs/page1-0.png
ADDED
![]() |
tangram_pngs/page1-1.png
ADDED
![]() |
tangram_pngs/page1-10.png
ADDED
![]() |
tangram_pngs/page1-103.png
ADDED
![]() |
tangram_pngs/page1-105.png
ADDED
![]() |
tangram_pngs/page1-106.png
ADDED
![]() |
tangram_pngs/page1-107.png
ADDED
![]() |
tangram_pngs/page1-108.png
ADDED
![]() |
tangram_pngs/page1-109.png
ADDED
![]() |
tangram_pngs/page1-110.png
ADDED
![]() |
tangram_pngs/page1-112.png
ADDED
![]() |
tangram_pngs/page1-113.png
ADDED
![]() |
tangram_pngs/page1-114.png
ADDED
![]() |
tangram_pngs/page1-116.png
ADDED
![]() |
tangram_pngs/page1-117.png
ADDED
![]() |
tangram_pngs/page1-118.png
ADDED
![]() |
tangram_pngs/page1-119.png
ADDED
![]() |
tangram_pngs/page1-122.png
ADDED
![]() |
tangram_pngs/page1-125.png
ADDED
![]() |
tangram_pngs/page1-128.png
ADDED
![]() |
tangram_pngs/page1-129.png
ADDED
![]() |
tangram_pngs/page1-13.png
ADDED
![]() |
tangram_pngs/page1-130.png
ADDED
![]() |
tangram_pngs/page1-132.png
ADDED
![]() |
tangram_pngs/page1-133.png
ADDED
![]() |
tangram_pngs/page1-136.png
ADDED
![]() |
tangram_pngs/page1-137.png
ADDED
![]() |
tangram_pngs/page1-14.png
ADDED
![]() |