mixed-modality-search committed on
Commit
6eb346b
·
1 Parent(s): c63e6b2
Files changed (1) hide show
  1. main.py +24 -17
main.py CHANGED
@@ -24,11 +24,11 @@ def load_and_decrypt_qrel(secret_key):
24
  except Exception as e:
25
  raise ValueError(f"Failed to decrypt answer file: {str(e)}")
26
 
27
- def recall_at_k(rank_list, relevant_ids, k=1):
28
- return int(any(item in relevant_ids for item in rank_list[:k]))
29
 
30
- def ndcg_at_k(rank_list, rel_dict, k):
31
- all_items = list(dict.fromkeys(rank_list + list(rel_dict.keys())))
32
 
33
  y_true = [rel_dict.get(item, 0) for item in all_items]
34
 
@@ -46,14 +46,14 @@ def evaluate(pred_data, qrel_dict):
46
 
47
  for item in queries:
48
  qid = item["query_id"]
49
- rank_list = item["rank_list"].split(",")
50
- rank_list = [x.strip() for x in rank_list if x.strip()]
51
  rel_dict = qrel_dict[dataset].get(qid, {})
52
  relevant_ids = [cid for cid, score in rel_dict.items() if score > 0]
53
 
54
- recall_1.append(recall_at_k(rank_list, relevant_ids, 1))
55
- ndcg_10.append(ndcg_at_k(rank_list, rel_dict, 10))
56
- ndcg_100.append(ndcg_at_k(rank_list, rel_dict, 100))
57
 
58
  results[dataset] = {
59
  "Recall@1": round(np.mean(recall_1) * 100, 2),
@@ -85,19 +85,26 @@ def process_json(file):
85
  # ==== Launch Gradio App ====
86
  def main_gradio():
87
  example_json = '''{
88
- "mscoco": [
89
- {"query_id": "1", "rank_list": "5, 2, 8"},
90
- {"query_id": "2", "rank_list": "9, 1, 3"}
91
  ],
92
- "google_wit": [
93
- {"query_id": "3", "rank_list": "11, 5, 22"}
 
 
 
 
 
 
 
94
  ]
95
  }'''
96
  gr.Interface(
97
  fn=process_json,
98
- inputs=gr.File(label="Upload Prediction JSON"),
99
- outputs=gr.Textbox(label="Evaluation Metrics"),
100
- title="Mixed-Modality Retrieval Evaluation",
101
  description="Upload a prediction JSON to evaluate Recall@1, NDCG@10, and NDCG@100 against encrypted qrels.\n\nExample input:\n" + example_json
102
  ).launch(share=True)
103
 
 
24
  except Exception as e:
25
  raise ValueError(f"Failed to decrypt answer file: {str(e)}")
26
 
27
+ def recall_at_k(corpus_top_100_list, relevant_ids, k=1):
28
+ return int(any(item in relevant_ids for item in corpus_top_100_list[:k]))
29
 
30
+ def ndcg_at_k(corpus_top_100_list, rel_dict, k):
31
+ all_items = list(dict.fromkeys(corpus_top_100_list + list(rel_dict.keys())))
32
 
33
  y_true = [rel_dict.get(item, 0) for item in all_items]
34
 
 
46
 
47
  for item in queries:
48
  qid = item["query_id"]
49
+ corpus_top_100_list = item["corpus_top_100_list"].split(",")
50
+ corpus_top_100_list = [x.strip() for x in corpus_top_100_list if x.strip()]
51
  rel_dict = qrel_dict[dataset].get(qid, {})
52
  relevant_ids = [cid for cid, score in rel_dict.items() if score > 0]
53
 
54
+ recall_1.append(recall_at_k(corpus_top_100_list, relevant_ids, 1))
55
+ ndcg_10.append(ndcg_at_k(corpus_top_100_list, rel_dict, 10))
56
+ ndcg_100.append(ndcg_at_k(corpus_top_100_list, rel_dict, 100))
57
 
58
  results[dataset] = {
59
  "Recall@1": round(np.mean(recall_1) * 100, 2),
 
85
  # ==== Launch Gradio App ====
86
  def main_gradio():
87
  example_json = '''{
88
+ "Google_WIT": [
89
+ {"query_id": "1", "corpus_top_100_list": "5, 2, 8, ..."},
90
+ {"query_id": "2", "corpus_top_100_list": "90, 13, 3, ..."}
91
  ],
92
+ "MSCOCO": [
93
+ {"query_id": "3", "corpus_top_100_list": "122, 35, 22, ..."},
94
+ {"query_id": "2", "corpus_top_100_list": "90, 19, 3, ..."}
95
+ ],
96
+ "OVEN": [
97
+ {"query_id": "3", "corpus_top_100_list": "11, 15, 22, ..."}
98
+ ],
99
+ "VisualNews": [
100
+ {"query_id": "3", "corpus_top_100_list": "101, 35, 22, ..."}
101
  ]
102
  }'''
103
  gr.Interface(
104
  fn=process_json,
105
+ inputs=gr.File(label="Upload Retrieval Result (JSON)"),
106
+ outputs=gr.Textbox(label="Results"),
107
+ title="Automated Evaluation of MixBench",
108
  description="Upload a prediction JSON to evaluate Recall@1, NDCG@10, and NDCG@100 against encrypted qrels.\n\nExample input:\n" + example_json
109
  ).launch(share=True)
110