Kamichanw committed
Commit 5a12027 · verified · 1 parent: 3663af6

Update vqa_accuracy.py

Files changed (1):
  vqa_accuracy.py  +7 -18
vqa_accuracy.py CHANGED
@@ -5,7 +5,7 @@ import re
 _DESCRIPTION = """
 VQA accuracy is a evaluation metric which is robust to inter-human variability in phrasing the answers:
 $$
-\\text{Acc}(\\textit{ans}) = \\min \\left( \\frac{\\text{# humans that said \\textit{ans}}{3}, 1 \\right)
+\\text{Acc}(ans) = \\min \\left( \\frac{\\text{# humans that said }ans}{3}, 1 \\right)
 $$
 Where `ans` is answered by machine. In order to be consistent with 'human accuracies', machine accuracies are averaged over all 10 choose 9 sets of human annotators.
 """
@@ -17,9 +17,9 @@ Args:
     references (`list` of `str` lists): Ground truth answers.
     answer_types (`list` of `str`, *optional*): Answer types corresponding to each questions.
     questions_type (`list` of `str`, *optional*): Question types corresponding to each questions.
-    precision (`int`, defaults to 2): The precision of results.
+
 Returns:
-    visual question answering accuracy (`float` or `int`): Accuracy accuracy. Minimum possible value is 0. Maximum possible value is 1.0, or the number of examples input, if `normalize` is set to `True`.. A higher accuracy means higher accuracy.
+    visual question answering accuracy (`float` or `int`): Accuracy accuracy. Minimum possible value is 0. Maximum possible value is 100.
 
 """
 
@@ -250,14 +250,7 @@ class VQAAccuracy(evaluate.Metric):
             ],
         )
 
-    def _compute(
-        self,
-        predictions,
-        references,
-        answer_types=None,
-        question_types=None,
-        precision=2,
-    ):
+    def _compute(self, predictions, references, answer_types=None, question_types=None):
         if answer_types is None:
             answer_types = [None] * len(predictions)
 
@@ -300,21 +293,17 @@ class VQAAccuracy(evaluate.Metric):
                 ques_type_dict[ques_type].append(vqa_acc)
 
         # the following key names follow the naming of the official evaluation results
-        result = {"overall": round(100 * sum(total) / len(total), precision)}
+        result = {"overall": 100 * sum(total) / len(total)}
 
         if len(ans_type_dict) > 0:
             result["perAnswerType"] = {
-                ans_type: round(
-                    100 * sum(accuracy_list) / len(accuracy_list), precision
-                )
+                ans_type: 100 * sum(accuracy_list) / len(accuracy_list)
                 for ans_type, accuracy_list in ans_type_dict.items()
             }
 
         if len(ques_type_dict) > 0:
             result["perQuestionType"] = {
-                ques_type: round(
-                    100 * sum(accuracy_list) / len(accuracy_list), precision
-                )
+                ques_type: 100 * sum(accuracy_list) / len(accuracy_list)
                 for ques_type, accuracy_list in ques_type_dict.items()
             }
 
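
For context, the formula in _DESCRIPTION works out as below. This is a minimal sketch of the per-question score only, under the assumption of the standard 10 human answers per question; it skips the answer normalization the real metric performs, and vqa_accuracy_sketch is a hypothetical helper, not part of the committed file. The metric itself reports 100 times the mean of these per-question scores.

# Sketch of Acc(ans) = min(#humans that said ans / 3, 1),
# averaged over all 10 choose 9 subsets of the human annotators.
def vqa_accuracy_sketch(prediction: str, human_answers: list[str]) -> float:
    accs = []
    for i in range(len(human_answers)):
        # leave annotator i out, score against the remaining 9
        others = human_answers[:i] + human_answers[i + 1:]
        matches = sum(1 for a in others if a == prediction)
        accs.append(min(matches / 3, 1.0))
    return sum(accs) / len(accs)

# e.g. 3 of 10 humans said "2": the average over the 10 leave-one-out
# subsets is (3 * 2/3 + 7 * 1.0) / 10 = 0.9
print(vqa_accuracy_sketch("2", ["2"] * 3 + ["3"] * 7))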
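After this commit, _compute no longer accepts a precision argument and the returned percentages are no longer rounded, so callers round the results themselves if they want fixed precision. A hedged usage sketch follows; the load path "Kamichanw/vqa_accuracy" and the example answer/question type strings are assumptions, not taken from the committed file.

import evaluate

# Load path is an assumption; substitute the actual repo id of this metric.
vqa_accuracy = evaluate.load("Kamichanw/vqa_accuracy")

results = vqa_accuracy.compute(
    predictions=["2", "yes"],
    references=[["2"] * 10, ["yes"] * 6 + ["no"] * 4],  # 10 human answers each
    answer_types=["number", "yes/no"],       # optional, enables perAnswerType
    question_types=["how many", "is the"],   # optional, enables perQuestionType
)

# Results are raw floats in [0, 100]; round them yourself if needed,
# e.g. round(results["overall"], 2) to mimic the old precision=2 default.
print(results["overall"])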