chenzizhao committed
Commit 2f56479 · 0 Parent(s)
This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full list.
Files changed (50)
  1. .gitattributes +35 -0
  2. .gitignore +4 -0
  3. README.md +12 -0
  4. adapter.py +244 -0
  5. app.py +335 -0
  6. config_generator.py +51 -0
  7. dataset_splits/dev_imgs.pkl +0 -0
  8. dataset_splits/test_imgs.pkl +0 -0
  9. dataset_splits/train_imgs.pkl +0 -0
  10. requirements.txt +9 -0
  11. tangram_pngs/page-A.png +0 -0
  12. tangram_pngs/page-B.png +0 -0
  13. tangram_pngs/page-C.png +0 -0
  14. tangram_pngs/page-D.png +0 -0
  15. tangram_pngs/page-E.png +0 -0
  16. tangram_pngs/page-F.png +0 -0
  17. tangram_pngs/page-G.png +0 -0
  18. tangram_pngs/page-H.png +0 -0
  19. tangram_pngs/page-I.png +0 -0
  20. tangram_pngs/page-J.png +0 -0
  21. tangram_pngs/page-K.png +0 -0
  22. tangram_pngs/page-L.png +0 -0
  23. tangram_pngs/page1-0.png +0 -0
  24. tangram_pngs/page1-1.png +0 -0
  25. tangram_pngs/page1-10.png +0 -0
  26. tangram_pngs/page1-103.png +0 -0
  27. tangram_pngs/page1-105.png +0 -0
  28. tangram_pngs/page1-106.png +0 -0
  29. tangram_pngs/page1-107.png +0 -0
  30. tangram_pngs/page1-108.png +0 -0
  31. tangram_pngs/page1-109.png +0 -0
  32. tangram_pngs/page1-110.png +0 -0
  33. tangram_pngs/page1-112.png +0 -0
  34. tangram_pngs/page1-113.png +0 -0
  35. tangram_pngs/page1-114.png +0 -0
  36. tangram_pngs/page1-116.png +0 -0
  37. tangram_pngs/page1-117.png +0 -0
  38. tangram_pngs/page1-118.png +0 -0
  39. tangram_pngs/page1-119.png +0 -0
  40. tangram_pngs/page1-122.png +0 -0
  41. tangram_pngs/page1-125.png +0 -0
  42. tangram_pngs/page1-128.png +0 -0
  43. tangram_pngs/page1-129.png +0 -0
  44. tangram_pngs/page1-13.png +0 -0
  45. tangram_pngs/page1-130.png +0 -0
  46. tangram_pngs/page1-132.png +0 -0
  47. tangram_pngs/page1-133.png +0 -0
  48. tangram_pngs/page1-136.png +0 -0
  49. tangram_pngs/page1-137.png +0 -0
  50. tangram_pngs/page1-14.png +0 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,4 @@
+ __pycache__/
+ *.pyc
+ .DS_Store
+ .vscode
README.md ADDED
@@ -0,0 +1,12 @@
+ ---
+ title: Respect
+ emoji: 🫡
+ colorFrom: pink
+ colorTo: yellow
+ sdk: gradio
+ sdk_version: 4.44.0
+ app_file: app.py
+ pinned: false
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
adapter.py ADDED
@@ -0,0 +1,244 @@
+ import logging
+ import re
+ from functools import cache
+ from pathlib import Path
+ from typing import List, Set, Tuple, TypeVar
+
+ import torch
+ from PIL import Image
+
+ from utils import device, nested_apply, sorted_list
+
+ RE_PATTERN = r'^(deselect\s[A-Z](?:\s[A-Z])*(?:\sselect\s[A-Z](?:\s[A-Z])*)?|select\s[A-Z](?:\s[A-Z])*)$'  # noqa
+
+
+ # Name type, newtype of str, e.g. "page4-249.png"
+ N = TypeVar('N')
+
+ ALPHABET = 'ABCDEFGHIJ'  # we only have 10 images
+ LEGAL_TOKEN_IDS = [2, 315, 330, 334, 365, 382, 384, 401, 413,
+                    420, 475, 5339, 634, 17960, 32002]  # A-J, <end_of_utterance>, </s>, 'select', and 'deselect'
+
+
+ MINI_DECODER = {
+     384: 'D',
+     # 2: '</s>',
+     32002: '<end_of_utterance>',
+     420: 'G', 17960: 'elect',
+     330: 'A', 365: 'B', 334: 'C', 5339: 'select', 401: 'F', 475: 'J',
+     634: 'des', 315: 'I', 413: 'E', 382: 'H'}
+
+
+ class AlphabeticNameHash:
+
+     @cache
+     def __init__(self, context: List[N]) -> None:
+         self._forward_map = {im: ALPHABET[i] for i, im in enumerate(context)}
+         self._backward_map = {ALPHABET[i]: im for i, im in enumerate(context)}
+
+     def hash(self, im: N) -> str:
+         return self._forward_map[im]
+
+     def unhash(self, i: str) -> N:
+         return self._backward_map[i]
+
+     def valid_hash(self, i: str) -> bool:
+         return i in self._backward_map
+
+
+ class IdeficsAdapter:
+
+     PAD_TOKEN_ID = 0
+     LABEL_MASK_ID = 32001  # idefics2: image_token_id
+     LEGAL_TOKEN_IDS = LEGAL_TOKEN_IDS
+     LEGAL_TOKEN_MASK = torch.zeros(32003, requires_grad=False)\
+         .index_fill_(0, torch.tensor(LEGAL_TOKEN_IDS), 1).to(device=device(), dtype=torch.bool)
+     SUPPRESS_TOKEN_IDS = list(set(range(32003)) - set(LEGAL_TOKEN_IDS))
+
+     def __init__(self, image_folder: str, processor) -> None:
+         self.t_max_length = 2048
+         self.image_folder = Path(image_folder)
+         self.image_cache = {}
+         self.processor = processor
+         self.tokenizer = self.processor.tokenizer
+
+     def get_image(self, im_name: N) -> Image.Image:
+         if im_name not in self.image_cache:
+             self.image_cache[im_name] = Image.open(
+                 self.image_folder.joinpath(im_name))
+         return self.image_cache[im_name]
+
+     def unhash(self, context: List[N], c: str):
+         return AlphabeticNameHash(tuple(context)).unhash(c)
+
+     def valid_hash(self, context: List[N], c: str):
+         return AlphabeticNameHash(tuple(context)).valid_hash(c)
+
+     def parse(self, context: List[N], decoded_out: str,
+               currently_selected: List[N]) -> List[str]:
+         h = AlphabeticNameHash(tuple(context))
+         logging.debug(f"{context=}")
+         # do inference
+         logging.debug(f"{decoded_out=}")
+         selection, deselection = self.parse_raw(decoded_out)
+
+         hashed_currently_selected = {h.hash(n) for n in currently_selected}
+         desel_to_remove = deselection - hashed_currently_selected
+         if len(desel_to_remove) > 0:
+             logging.debug(f"warn! {desel_to_remove=}")
+             deselection = deselection - desel_to_remove
+
+         sel_to_remove = selection & hashed_currently_selected
+         if len(sel_to_remove) > 0:
+             logging.debug(f"warn! {sel_to_remove=}")
+             selection = selection - sel_to_remove
+
+         logging.debug("post strict cleaning")
+         logging.debug(f"{selection=}")
+         logging.debug(f"{deselection=}")
+
+         model_clicks = selection | deselection
+         logging.debug(f"{model_clicks=}")
+         model_clicks_png = [h.unhash(n)
+                             for n in model_clicks if h.valid_hash(n)]
+         logging.debug(f"{model_clicks_png=}")
+         return model_clicks_png
+
+     @staticmethod
+     def parse_raw(text: str) -> Tuple[Set[N], Set[N]]:
+         last_answer = text.strip()
+         if ":" in text:
+             last_answer_pattern = r":.*$"
+             xs = re.findall(last_answer_pattern, text)
+             last_answer = xs[0].removeprefix(":").strip()
+         xs = re.search(RE_PATTERN, last_answer)
+         if xs is None:
+             print(f"{last_answer=}")
+             print("did not pass regex")
+             return set(), set()
+
+         select_pattern = r"(?<!de)select( [A-J])+$"
+         xs = re.search(select_pattern, last_answer)
+         if xs is not None:
+             xs = xs.group()
+         selections = set(xs.split(" ")[1:]) if xs else set()
+
+         deselect_pattern = r"^deselect( [A-J])+"
+         xs = re.search(deselect_pattern, last_answer)
+         if xs is not None:
+             xs = xs.group()
+         deselections = set(xs.split(" ")[1:]) if xs else set()
+
+         return selections, deselections
+
+     def compose(self, context, chats, previous_selected, hash_images, padding):
+         select_accum, deselect_accum, clickss = self.unfold_select_deselect(
+             previous_selected)
+
+         select_accum = select_accum + [[]]
+         deselect_accum = deselect_accum + [[]]
+         previous_selected = [[]] + previous_selected  # old states, pre click
+         assert len(chats) == len(select_accum) == len(
+             deselect_accum) == len(previous_selected)
+
+         messages, images = self.build_processor_input(
+             context, chats, select_accum, deselect_accum, previous_selected,
+             hash_images, omit_last_answer=True, sort_names=True,
+             omit_context=False, chat_feedback=None)
+         prompt = self.processor.apply_chat_template(
+             messages, add_generation_prompt=True)
+         prompt = prompt.strip()
+         logging.debug(prompt)
+         # Keep consistent with train_script
+         inputs = self.processor(
+             text=prompt, images=images,
+             padding=padding, truncation=True, max_length=self.t_max_length,
+             return_tensors="pt")
+         return inputs
+
+     def build_processor_input(self, image_pngs: List[N], chats: List[str],
+                               select_accum: List[List[N]],
+                               deselect_accum: List[List[N]],
+                               pre_click_selected_accum: List[List[N]],
+                               hash_image: bool, omit_last_answer: bool,
+                               sort_names: bool, omit_context: bool,
+                               chat_feedback: str):
+         def _text_content(text): return {"type": "text", "text": text}
+
+         def _image_content(): return {"type": "image"}
+
+         def _user_prompt(content): return {"role": "user", "content": content}
+
+         def _assistant_prompt(content): return {
+             "role": "assistant", "content": content}
+
+         def _system_prompt(content): return {
+             "role": "system", "content": content}
+
+         def _current_state(selected: List[N]):
+             if len(selected) == 0:
+                 return 'none is selected'
+             return f'{" ".join(selected)} currently selected'
+
+         def _listener_action(select: List[N], deselect: List[N]):
+             if len(select) == 0 and len(deselect) == 0:
+                 return 'nothing'
+             if len(select) == 0:
+                 return f'deselect {" ".join(deselect)}'
+             if len(deselect) == 0:
+                 return f'select {" ".join(select)}'
+             return f'deselect {" ".join(deselect)} select {" ".join(select)}'
+
+         func = AlphabeticNameHash(tuple(image_pngs)).hash if hash_image else id
+         context, select_accum, deselect_accum, pre_click_selected_accum = nested_apply(
+             func, (image_pngs, select_accum, deselect_accum, pre_click_selected_accum))
+
+         prompt = []
+         images = []
+         if not omit_context:
+             images = [self.get_image(im) for im in image_pngs]
+             images_and_names_content = []
+             for im_name in context:
+                 images_and_names_content.append(_image_content())
+                 images_and_names_content.append(_text_content(im_name))
+             prompt.append(_system_prompt(images_and_names_content))
+         if not len(chats) == len(select_accum) == len(deselect_accum) == len(pre_click_selected_accum):
+             logging.error(f"{chats=}")
+             logging.error(f"{select_accum=}")
+             logging.error(f"{deselect_accum=}")
+             logging.error(f"{pre_click_selected_accum=}")
+             assert False
+         for i, (chat, select, deselect, pre_click_selected) in enumerate(
+                 zip(chats, select_accum, deselect_accum, pre_click_selected_accum)):
+             if sort_names:
+                 select = sorted(select)
+                 deselect = sorted(deselect)
+                 pre_click_selected = sorted(pre_click_selected)
+
+             prompt.append(_system_prompt(
+                 [_text_content(_current_state(pre_click_selected))]))
+             prompt.append(_user_prompt([_text_content(chat)]))
+             prompt.append(_assistant_prompt(
+                 [_text_content(_listener_action(select, deselect))]))
+         if omit_last_answer:
+             # idefics2 has processor.apply_chat_template(messages, add_generation_prompt=True) instead
+             prompt.pop(-1)
+         if chat_feedback is not None:
+             prompt.append(_user_prompt([_text_content(chat_feedback)]))
+         return prompt, images
+
+     def unfold_select_deselect(self, previous_selected: List[List[N]]) -> Tuple[List[List[N]], List[List[N]], List[List[N]]]:
+         # previous_selected[i] is what is selected AFTER the i-th turn
+         num_turns = len(previous_selected)
+         selected: List[List[str]] = []  # turn-wise selection
+         deselected: List[List[str]] = []  # turn-wise deselection
+         clicks: List[List[str]] = []
+         # combines turn-wise newly selected and newly deselected
+         prev_selected = set()
+         for turn in range(num_turns):
+             curr_selected = set(previous_selected[turn])
+             newly_selected = curr_selected - prev_selected
+             newly_deselected = prev_selected - curr_selected
+             selected.append(sorted_list(newly_selected))
+             deselected.append(sorted_list(newly_deselected))
+             clicks.append(sorted_list(newly_selected | newly_deselected))
+             prev_selected = curr_selected.copy()
+         return selected, deselected, clicks
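
A minimal usage sketch for the action grammar above (hypothetical, not part of this commit): `parse_raw` is a `@staticmethod`, so it can be exercised without loading a model, assuming `adapter.py` and its `utils` dependency import cleanly on the local device.

    # Hedged sketch: exercise IdeficsAdapter.parse_raw directly.
    from adapter import IdeficsAdapter

    # "deselect ... select ..." passes RE_PATTERN; the text after the first
    # colon is treated as the answer, and letters alias context images.
    sel, desel = IdeficsAdapter.parse_raw("listener: deselect A select B C")
    assert sel == {"B", "C"} and desel == {"A"}

    # Anything outside the grammar is rejected wholesale.
    assert IdeficsAdapter.parse_raw("pick B please") == (set(), set())
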
app.py ADDED
@@ -0,0 +1,335 @@
+ import dataclasses
+ import logging
+ import os
+ from typing import Any, Dict, List
+
+ import gradio as gr  # type: ignore
+ import PIL.Image as Image
+ import PIL.ImageOps as ImageOps
+ import spaces  # type: ignore
+ import torch
+ from peft import PeftModel  # type: ignore
+ from transformers import AutoProcessor  # type: ignore
+ from transformers import Idefics2ForConditionalGeneration, Idefics2Processor
+
+ from adapter import IdeficsAdapter
+ from config_generator import GameConfig, generate_game_config
+ from utils import device, nested_to_device, sorted_list
+ import copy
+
+ ### Constants
+ css = """
+ .radio-group .wrap {
+     display: grid;
+     grid-template-columns: repeat(5, 1fr);
+     grid-template-rows: repeat(5, 1fr);
+     width: 100%;
+     height: 100%;
+ }
+ """
+ IMG_DIR = "tangram_pngs"
+
+
+ ### Bot server
+
+ GEN_KWS: Dict[str, Any] = {
+     "max_new_tokens": 10,
+     "do_sample": True,
+     "temperature": 1.0,
+     "output_logits": True,
+     "return_dict_in_generate": True,
+     "remove_invalid_values": True,  # just to be safe
+     "renormalize_logits": True,
+     "suppress_tokens": IdeficsAdapter.SUPPRESS_TOKEN_IDS
+ }
+
+
+ @spaces.GPU(duration=20)
+ def get_model_response(  # predict
+         model: PeftModel, adapter_name: str, adapter: IdeficsAdapter,
+         image_paths: List[str], chat: str, chats: List[str],
+         previous_selected: List[List[str]]
+ ) -> List[str]:
+     if model.active_adapter != adapter_name:
+         model.set_adapter(adapter_name)
+
+     model.to(device())
+
+     new_chats = chats + [chat]
+     currently_selected = previous_selected[-1] if len(previous_selected) > 0 else []
+     model_input: Dict[str, Any] = adapter.compose(  # type: ignore
+         image_paths, new_chats, previous_selected, True, False)
+     model_input = nested_to_device(model_input)  # type: ignore
+
+     with torch.inference_mode(), torch.autocast(device_type=device().type,
+                                                 dtype=torch.bfloat16):
+         model_output = model.generate(**model_input, **GEN_KWS)  # type: ignore
+
+     decoded_out: str = adapter.tokenizer.decode(  # type: ignore
+         model_output.sequences[0], skip_special_tokens=True)
+     model_clicks = adapter.parse(
+         image_paths, decoded_out, currently_selected)  # type: ignore
+
+     if len(model_clicks) == 0:
+         logging.warning("empty clicks by model")
+         model_clicks = [image_paths[0]]
+         logging.debug(f"{image_paths=}")
+         logging.debug(f"selecting {model_clicks}")
+         prob = -1
+     else:
+         prob = -3
+     logging.debug(f"{prob=}")
+     logging.info(f"User input: {chat}")
+     logging.info(f"Model selected: {model_clicks}")
+     logging.debug(f"Model output: {decoded_out}")
+     return model_clicks
+
+
+ def get_model() -> PeftModel:
+     model_id = 'lil-lab/respect'
+     checkpoint = "HuggingFaceM4/idefics2-8b"
+     model = Idefics2ForConditionalGeneration.from_pretrained(  # type: ignore
+         checkpoint, torch_dtype=torch.bfloat16,
+     )
+     peft_model = PeftModel.from_pretrained(  # type: ignore
+         model, model_id, adapter_name="r6_bp", is_trainable=False, revision="r6_bp")
+
+     # Add the other adapter - a hack to avoid a config conflict
+     lora_config = copy.deepcopy(peft_model.active_peft_config)
+     targets = list(set(n[:n.find('lora')-1] for n, _ in model.named_parameters()
+                        if 'lora' in n))
+     lora_config.target_modules = targets
+     peft_model.add_adapter("r0", lora_config)
+     peft_model.load_adapter(model_id, "r0", is_trainable=False, revision="r0",
+                             peft_config=lora_config)
+     return peft_model
+
+
+ def get_processor() -> Idefics2Processor:
+     checkpoint = "HuggingFaceM4/idefics2-8b"
+     processor = AutoProcessor.from_pretrained(  # type: ignore
+         checkpoint, do_image_splitting=False,
+         size={"longest_edge": 224, "shortest_edge": 224})
+     return processor  # type: ignore
+
+
+ def get_adapter() -> IdeficsAdapter:
+     processor = get_processor()
+     return IdeficsAdapter(IMG_DIR, processor)
+
+
+ ### Game logic
+
+ @dataclasses.dataclass(frozen=False)
+ class GameState:
+     config: GameConfig
+     adapter_name: str
+     chats: List[str]
+     currently_selected: List[str]
+     selected_accum: List[List[str]]
+     clicks_accum: List[List[str]]
+     turn: int = 0
+
+     def has_ended(self):
+         return self.has_successfully_ended() or self.turn >= 10
+
+     def has_successfully_ended(self):
+         return set(self.currently_selected) == set(self.config.targets)
+
+     ### UI helpers
+
+     def serialize_conversation(self):
+         output = [f"Turn {i+1}: {message}"
+                   for i, message in enumerate(self.chats)]
+         return "\n".join(output)
+
+     def markup_images(self):
+         context = self.config.speaker_context
+         targets = self.config.targets
+         selected = self.currently_selected
+         changes = self.selected_accum[-1] if len(self.selected_accum) > 0 else []
+
+         tangram_list = self._display_context(context, targets, changes, selected)
+         # return [(img, f"Image {i+1}") for i, img in enumerate(tangram_list)]
+         return tangram_list
+
+     @staticmethod
+     def _display_context(context: List[str], targets: List[str],
+                          changes: List[str], selected: List[str]) -> List[Image.Image]:
+         tangram_list: List[Image.Image] = []
+         arrow = Image.open("yellow_circle.png").resize((20, 20)).convert("RGBA")
+         for img in context:
+             image = Image.open(os.path.join(IMG_DIR, img)).resize((60, 60)).convert("RGB")
+             image = ImageOps.expand(image, border=2, fill="white")
+             if img in targets and img in selected:  # listener selected a target image
+                 image = ImageOps.expand(image, border=10, fill="green")
+             elif img in targets and img not in selected:  # unselected target
+                 image = ImageOps.expand(image, border=10, fill="black")
+             elif img in selected and img not in targets:  # listener selected a wrong image
+                 image = ImageOps.expand(image, border=10, fill="red")
+             else:
+                 image = ImageOps.expand(image, border=10, fill="white")
+             image = ImageOps.expand(image, border=2, fill="white")
+             if img in changes:
+                 image.paste(arrow, (68, 0), mask=arrow)
+             tangram_list.append(image)
+         return tangram_list
+
+
+ class GameFlow:
+
+     @classmethod
+     def initialize(cls, model_iteration: str) -> GameState:
+         config = generate_game_config()
+         adapter_name = "r0" if model_iteration == "Initial System" else "r6_bp"
+         state = GameState(
+             config=config,
+             adapter_name=adapter_name,
+             chats=[],
+             currently_selected=[],
+             selected_accum=[],
+             clicks_accum=[],
+             turn=0,
+         )
+         return state
+
+     @classmethod
+     def progress(cls, state: GameState, chat: str,
+                  model: PeftModel,
+                  adapter: IdeficsAdapter) -> GameState:
+         turn = state.turn
+         model_context_images = state.config.listener_context
+
+         model_clicks = get_model_response(
+             model, state.adapter_name, adapter,
+             model_context_images, chat,
+             state.chats, state.selected_accum
+         )
+
+         # symmetric difference (apply deselection, then selection)
+         currently_selected2 = sorted_list(
+             (set(state.currently_selected) - set(model_clicks))
+             | (set(model_clicks) - set(state.currently_selected))
+         )
+
+         state2 = GameState(
+             # constants
+             config=state.config,
+             adapter_name=state.adapter_name,
+             # updates
+             chats=state.chats.copy() + [chat],
+             currently_selected=currently_selected2,
+             selected_accum=state.selected_accum.copy() + [currently_selected2],
+             clicks_accum=state.clicks_accum.copy() + [model_clicks],
+             turn=turn + 1,
+         )
+         return state2
+
+
+ ### UI
+
+ def create_app_inner():
+     ### layout
+     gr.Markdown("# Tangram Multi-Reference Game")
+     gr.Markdown(
+         '### You will be playing a multi-reference game against a model. \
+         To start a game, first select whether you wish to play against our \
+         initial trained model ("Initial System") or \
+         our model at the end of continual learning ("Final System"), \
+         then press the "Start Game" button. \
+         You will take on the "speaker" role in each round. \
+         Your goal is to describe the target images (via messages in the textbox) \
+         so that the model can guess which ones they are.'
+     )
+
+     gr.Markdown("Targets have black borders. Correctly selected targets have green borders. Incorrectly selected images have red borders. Actions are marked with a yellow dot.")
+
+     gr.Markdown("The listener cannot see the borders or colors, and it sees the images in a different order.")
+
+     gr.Markdown(
+         '### Press "Send" to submit your message and proceed to the next turn. \
+         You have 10 turns in total.'
+     )
+
+     with gr.Row():
+         model_iteration = gr.Radio(["Initial System", "Final System"],
+                                    label="Model Iteration",
+                                    value="Final System")
+         start_btn = gr.Button("Start Game")
+
+     with gr.Row():
+         current_turn = gr.Textbox(label="TURN")
+         success = gr.Textbox(label="Success")
+
+     with gr.Row():
+         image_output = gr.Gallery(
+             label="CONTEXT", show_label=False, elem_id="gallery",
+             columns=5, rows=2, object_fit="contain", height="250px",
+             allow_preview=False, container=True, interactive=False
+         )
+
+     with gr.Row():
+         conversation_output = gr.Textbox(label="Interaction History")
+         user_input = gr.Textbox(label="Your Message as Speaker", interactive=True)
+
+     send_btn = gr.Button("Send", interactive=True)
+
+     ### globals
+     model = get_model()
+     adapter = get_adapter()
+     game_state = gr.State(value=None)
+
+     ### callbacks
+     def output_from_state(state: GameState):
+         has_ended = state.has_ended()
+         success = "success" if state.has_successfully_ended() else "failure"
+         return (
+             state.markup_images(),  # image_output
+             state.serialize_conversation(),  # conversation_output
+             f"{state.turn+1}/10",  # current_turn
+             success if has_ended else "n/a",  # success
+             gr.update(interactive=not has_ended, value=""),  # user_input
+             gr.update(interactive=not has_ended),  # send_btn
+             gr.update(interactive=has_ended),  # model_iteration
+             state,  # game_state
+         )
+
+     def on_start_interaction(model_iteration: str):
+         assert model_iteration in ["Initial System", "Final System"]
+         state = GameFlow.initialize(model_iteration)
+         return output_from_state(state)
+
+     def on_send_message(message: str, state: GameState):
+         nonlocal model
+         nonlocal adapter
+         if message.strip() == "":
+             logging.info("Empty message")
+             return output_from_state(state)
+         state = GameFlow.progress(state, message, model, adapter)
+         return output_from_state(state)
+
+     start_btn.click(
+         on_start_interaction,
+         inputs=[model_iteration],
+         outputs=[image_output, conversation_output, current_turn, success,
+                  user_input, send_btn, model_iteration, game_state],
+         queue=False
+     )
+
+     send_btn.click(
+         on_send_message,
+         inputs=[user_input, game_state],
+         outputs=[image_output, conversation_output, current_turn, success,
+                  user_input, send_btn, model_iteration, game_state],
+         queue=True
+     )
+
+
+ def create_app():
+     with gr.Blocks(css=css) as app:
+         create_app_inner()
+     return app
+
+
+ if __name__ == "__main__":
+     app = create_app()
+     app.queue()
+     app.launch()
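
The selection update in `GameFlow.progress` is a set symmetric difference: each image the model "clicks" toggles its membership in the current selection. A minimal sketch with made-up image names:

    # Hedged sketch of the click-toggle rule used in GameFlow.progress.
    prev = {"page1-10.png", "page1-13.png"}    # selected before this turn
    clicks = {"page1-13.png", "page1-14.png"}  # one deselection, one selection
    new = (prev - clicks) | (clicks - prev)    # as written in progress()
    assert new == prev ^ clicks == {"page1-10.png", "page1-14.png"}
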
config_generator.py ADDED
@@ -0,0 +1,51 @@
+ import dataclasses
+ import functools
+ import logging
+ import os
+ import pickle
+ import pprint
+ import random
+ from typing import List
+
+ EMPTY_DATA_PATH = "tangram_pngs/"
+ SPLIT_PATH = "dataset_splits/"
+
+
+ @dataclasses.dataclass(frozen=True)
+ class GameConfig:
+     speaker_context: List[str]
+     listener_context: List[str]
+     targets: List[str]
+
+
+ def generate_game_config() -> GameConfig:
+     corpus = _get_data()
+     context = random.sample(corpus, 10)
+     num_targets = random.randint(3, 5)
+     targets = random.sample(context, num_targets)
+     listener_order = list(range(10))
+     random.shuffle(listener_order)
+
+     config = GameConfig(
+         speaker_context=context,
+         listener_context=[context[i] for i in listener_order],
+         targets=targets,
+     )
+     logging.info(f"context_dict: {pprint.pformat(dataclasses.asdict(config))}")
+     return config
+
+ @functools.cache
+ def _get_data(restricted_dataset: bool = False):
+     if not restricted_dataset:
+         # 1013 images
+         paths = os.listdir(EMPTY_DATA_PATH)
+     else:
+         # 912 images
+         with open(os.path.join(SPLIT_PATH, "test_imgs.pkl"), 'rb') as f:
+             paths = pickle.load(f)
+         with open(os.path.join(SPLIT_PATH, "train_imgs.pkl"), 'rb') as f:
+             paths += pickle.load(f)
+         paths = [path + ".png" for path in paths]
+     dup_images = ["page6-51.png", "page6-66.png", "page4-170.png"]
+     paths = [path for path in paths if path != ".DS_Store" and path not in dup_images]
+     return paths
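
A hedged usage sketch for `generate_game_config` (seeding `random` is an assumption for reproducibility; the module does not seed it itself):

    import random

    from config_generator import generate_game_config

    random.seed(0)  # assumption: make the sampled config repeatable
    config = generate_game_config()
    assert len(config.speaker_context) == 10
    assert 3 <= len(config.targets) <= 5
    # The listener sees the same ten images, in a shuffled order.
    assert set(config.listener_context) == set(config.speaker_context)
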
dataset_splits/dev_imgs.pkl ADDED
Binary file (1.18 kB).
dataset_splits/test_imgs.pkl ADDED
Binary file (5.26 kB).
dataset_splits/train_imgs.pkl ADDED
Binary file (5.26 kB).
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ torch==2.2.0
+ datasets==2.18.0
+ transformers==4.40.0
+ accelerate==0.29.2
+ loralib==0.1.2
+ peft==0.10.0
+ nltk==3.8.1
+ gradio==4.44.1
+ spaces==0.30.4
tangram_pngs/page-A.png ADDED
tangram_pngs/page-B.png ADDED
tangram_pngs/page-C.png ADDED
tangram_pngs/page-D.png ADDED
tangram_pngs/page-E.png ADDED
tangram_pngs/page-F.png ADDED
tangram_pngs/page-G.png ADDED
tangram_pngs/page-H.png ADDED
tangram_pngs/page-I.png ADDED
tangram_pngs/page-J.png ADDED
tangram_pngs/page-K.png ADDED
tangram_pngs/page-L.png ADDED
tangram_pngs/page1-0.png ADDED
tangram_pngs/page1-1.png ADDED
tangram_pngs/page1-10.png ADDED
tangram_pngs/page1-103.png ADDED
tangram_pngs/page1-105.png ADDED
tangram_pngs/page1-106.png ADDED
tangram_pngs/page1-107.png ADDED
tangram_pngs/page1-108.png ADDED
tangram_pngs/page1-109.png ADDED
tangram_pngs/page1-110.png ADDED
tangram_pngs/page1-112.png ADDED
tangram_pngs/page1-113.png ADDED
tangram_pngs/page1-114.png ADDED
tangram_pngs/page1-116.png ADDED
tangram_pngs/page1-117.png ADDED
tangram_pngs/page1-118.png ADDED
tangram_pngs/page1-119.png ADDED
tangram_pngs/page1-122.png ADDED
tangram_pngs/page1-125.png ADDED
tangram_pngs/page1-128.png ADDED
tangram_pngs/page1-129.png ADDED
tangram_pngs/page1-13.png ADDED
tangram_pngs/page1-130.png ADDED
tangram_pngs/page1-132.png ADDED
tangram_pngs/page1-133.png ADDED
tangram_pngs/page1-136.png ADDED
tangram_pngs/page1-137.png ADDED
tangram_pngs/page1-14.png ADDED