flow3rdown committed
Commit 9aac1ca · 1 Parent(s): 9e48e7a
Files changed (1):
  1. app.py +36 -7
app.py CHANGED
@@ -80,11 +80,11 @@ processor = CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32')
 
 def single_inference_iit(head_img, head_id, tail_img, tail_id, question_txt, question_id):
     # (I, I) -> (T, ?)
-    head_ent_text, tail_ent_text = ent2description[head_id], ent2description[tail_id]
+    ques_ent_text = ent2description[question_id]
 
     inputs = tokenizer(
-        tokenizer.sep_token.join([analogy_ent2token[head_id] + " " + head_ent_text, "[R] ", analogy_ent2token[tail_id] + " " + tail_ent_text]),
-        tokenizer.sep_token.join([analogy_ent2token[question_id] + " ", "[R] ", "[MASK]"]),
+        tokenizer.sep_token.join([analogy_ent2token[head_id] + " ", "[R] ", analogy_ent2token[tail_id] + " "]),
+        tokenizer.sep_token.join([analogy_ent2token[question_id] + " " + ques_ent_text, "[R] ", "[MASK]"]),
         truncation="longest_first", max_length=128, padding="longest", return_tensors='pt', add_special_tokens=True)
     sep_idx = [[i for i, ids in enumerate(input_ids) if ids == tokenizer.sep_token_id] for input_ids in inputs['input_ids']]
     inputs['sep_idx'] = torch.tensor(sep_idx)
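
Note on the hunk above: the textual description moves from the (head, tail) example pair to the question entity, so the first segment now carries only the entity placeholder tokens while the question side carries its description. The two sep_token.join(...) strings are passed to the tokenizer as a text pair, and sep_idx records where each [SEP] lands. A minimal sketch of that encoding step, assuming a BERT-style tokenizer; the placeholder tokens [ENT_H], [ENT_T], [ENT_Q] stand in for the app's analogy_ent2token values and are not its real vocabulary:

import torch
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
tokenizer.add_special_tokens({"additional_special_tokens": ["[R]", "[ENT_H]", "[ENT_T]", "[ENT_Q]"]})

# First segment: (head, [R], tail) placeholders only; second: question entity + description, [R], [MASK].
example_side = tokenizer.sep_token.join(["[ENT_H] ", "[R] ", "[ENT_T] "])
question_side = tokenizer.sep_token.join(["[ENT_Q] a short entity description", "[R] ", "[MASK]"])

inputs = tokenizer(example_side, question_side,
                   truncation="longest_first", max_length=128,
                   padding="longest", return_tensors="pt", add_special_tokens=True)

# Record where every [SEP] sits so segment-level attention can be built later.
sep_idx = [[i for i, ids in enumerate(input_ids) if ids == tokenizer.sep_token_id]
           for input_ids in inputs["input_ids"]]
inputs["sep_idx"] = torch.tensor(sep_idx)

print(tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]))
print(sep_idx)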
@@ -108,8 +108,37 @@ def single_inference_iit(head_img, head_id, tail_img, tail_id, question_txt, que
 
     return answer
 
+
 def single_inference_tti(head_txt, head_id, tail_txt, tail_id, question_img, question_id):
-    return head_txt
+    # # (T, T) -> (I, ?)
+    # head_ent_text, tail_ent_text = ent2description[head_id], ent2description[tail_id]
+
+    # inputs = tokenizer(
+    #     tokenizer.sep_token.join([analogy_ent2token[head_id] + " " + head_ent_text, "[R] ", analogy_ent2token[tail_id] + " " + tail_ent_text]),
+    #     tokenizer.sep_token.join([analogy_ent2token[question_id] + " ", "[R] ", "[MASK]"]),
+    #     truncation="longest_first", max_length=128, padding="longest", return_tensors='pt', add_special_tokens=True)
+    # sep_idx = [[i for i, ids in enumerate(input_ids) if ids == tokenizer.sep_token_id] for input_ids in inputs['input_ids']]
+    # inputs['sep_idx'] = torch.tensor(sep_idx)
+    # inputs['attention_mask'] = inputs['attention_mask'].unsqueeze(1).expand([inputs['input_ids'].size(0), inputs['input_ids'].size(1), inputs['input_ids'].size(1)]).clone()
+    # for i, idx in enumerate(sep_idx):
+    #     inputs['attention_mask'][i, :idx[2], idx[2]:] = 0
+
+    # # image
+    # pixel_values = processor(images=[head_img, tail_img], return_tensors='pt')['pixel_values'].squeeze()
+    # inputs['pixel_values'] = pixel_values.unsqueeze(0)
+
+    # input_ids = inputs['input_ids']
+
+    # model_output = mkgformer.model(**inputs, return_dict=True)
+    # logits = model_output[0].logits
+    # bsz = input_ids.shape[0]
+
+    # _, mask_idx = (input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True) # bsz
+    # mask_logits = logits[torch.arange(bsz), mask_idx][:, analogy_entity_ids] # bsz, 1, entity
+    # answer = ent2text[list(analogy_ent2token.keys())[mask_logits.argmax().item()]]
+
+    return answer
+
 
 def blended_inference_iti(head_img, head_id, tail_txt, tail_id, question_img, question_id):
     return tail_txt
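
The second hunk replaces the placeholder return head_txt with a commented-out draft of the (T, T) -> (I, ?) pipeline. It mirrors single_inference_iit: segment-level attention masking built from sep_idx, CLIP-processed pixel_values for the head/tail images, then a [MASK]-token readout over the candidate analogy entities (as committed the body is still commented out, so answer is undefined when the function runs). Below is a self-contained, hedged sketch of just the readout step with dummy logits; the token ids and entity names are invented for illustration:

import torch

mask_token_id = 103                       # BERT's [MASK] id, assumed
vocab_size = 30522
analogy_entity_ids = [1001, 1002, 1003]   # hypothetical ids of the candidate entity tokens
entity_names = ["entity_a", "entity_b", "entity_c"]

input_ids = torch.tensor([[101, 2054, 103, 102]])        # one sequence containing a single [MASK]
logits = torch.randn(1, input_ids.size(1), vocab_size)   # stand-in for the model's output logits

bsz = input_ids.shape[0]
_, mask_idx = (input_ids == mask_token_id).nonzero(as_tuple=True)         # column of [MASK] per row
mask_logits = logits[torch.arange(bsz), mask_idx][:, analogy_entity_ids]  # keep candidate entities only
answer = entity_names[mask_logits.argmax().item()]
print(answer)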
@@ -143,7 +172,7 @@ def single_tab_iit():
         inputs=[head_image, head_ent, tail_image, tail_ent, question_text, question_ent],
         outputs=[output_text],
         cache_examples=False,
-        run_on_click=False
+        run_on_click=True
     )
 
 def single_tab_tti():
@@ -174,7 +203,7 @@ def single_tab_tti():
         inputs=[head_text, head_ent, tail_text, tail_ent, question_image, question_ent],
         outputs=[output_text],
         cache_examples=False,
-        run_on_click=False
+        run_on_click=True
     )
 
 def blended_tab_iti():
@@ -205,7 +234,7 @@ def blended_tab_iti():
         inputs=[head_image, head_ent, tail_txt, tail_ent, question_image, question_ent],
         outputs=[output_text],
         cache_examples=False,
-        run_on_click=False
+        run_on_click=True
     )
 
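
The last three hunks make the same one-line change in each demo tab: run_on_click flips from False to True in the gr.Examples blocks, so clicking an example now runs the corresponding inference function immediately instead of only filling the inputs (examples remain uncached). A minimal, hedged Gradio sketch of that pattern; the components and the dummy function are placeholders, not the app's real ones:

import gradio as gr

def fake_inference(question):
    # Placeholder for the real single/blended inference functions.
    return f"answer for: {question}"

with gr.Blocks() as demo:
    question = gr.Textbox(label="question")
    output_text = gr.Textbox(label="answer")
    gr.Examples(
        examples=[["cat is to kitten as dog is to ?"]],
        inputs=[question],
        outputs=[output_text],
        fn=fake_inference,
        cache_examples=False,
        run_on_click=True,   # clicking the example fills the input and runs fn immediately
    )

demo.launch()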
 
 