mixed-modality-search commited on
Commit
cf62495
·
1 Parent(s): 0161c89
Files changed (1) hide show
  1. main.py +105 -0
main.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import json
3
+ import os
4
+ import numpy as np
5
+ from cryptography.fernet import Fernet
6
+ from collections import defaultdict
7
+ from sklearn.metrics import ndcg_score
8
+
9
+ def load_and_decrypt_qrel(secret_key):
10
+ try:
11
+ with open("data/answer.enc", "rb") as enc_file:
12
+ encrypted_data = enc_file.read()
13
+ cipher = Fernet(secret_key.encode())
14
+ decrypted_data = cipher.decrypt(encrypted_data).decode("utf-8")
15
+ raw_data = json.loads(decrypted_data)
16
+
17
+ # qrel_dict: dataset -> query_id -> {corpus_id: score}
18
+ qrel_dict = defaultdict(lambda: defaultdict(dict))
19
+ for dataset, records in raw_data.items():
20
+ for item in records:
21
+ qid, cid, score = item["query_id"], item["corpus_id"], item["score"]
22
+ qrel_dict[dataset][qid][cid] = score
23
+ return qrel_dict
24
+ except Exception as e:
25
+ raise ValueError(f"Failed to decrypt answer file: {str(e)}")
26
+
27
+ def recall_at_k(rank_list, relevant_ids, k=1):
28
+ return int(any(item in relevant_ids for item in rank_list[:k]))
29
+
30
+ def ndcg_at_k(rank_list, rel_dict, k):
31
+ all_items = list(dict.fromkeys(rank_list + list(rel_dict.keys())))
32
+
33
+ y_true = [rel_dict.get(item, 0) for item in all_items]
34
+
35
+ y_score = [len(all_items) - i for i in range(len(all_items))]
36
+
37
+ return ndcg_score([y_true], [y_score], k=k)
38
+
39
+ def evaluate(pred_data, qrel_dict):
40
+ results = {}
41
+ for dataset, queries in pred_data.items():
42
+ if dataset not in qrel_dict:
43
+ continue
44
+
45
+ recall_1, ndcg_10, ndcg_100 = [], [], []
46
+
47
+ for item in queries:
48
+ qid = item["query_id"]
49
+ rank_list = item["rank_list"].split(",")
50
+ rank_list = [x.strip() for x in rank_list if x.strip()]
51
+ rel_dict = qrel_dict[dataset].get(qid, {})
52
+ relevant_ids = [cid for cid, score in rel_dict.items() if score > 0]
53
+
54
+ recall_1.append(recall_at_k(rank_list, relevant_ids, 1))
55
+ ndcg_10.append(ndcg_at_k(rank_list, rel_dict, 10))
56
+ ndcg_100.append(ndcg_at_k(rank_list, rel_dict, 100))
57
+
58
+ results[dataset] = {
59
+ "Recall@1": round(np.mean(recall_1) * 100, 2),
60
+ "NDCG@10": round(np.mean(ndcg_10) * 100, 2),
61
+ "NDCG@100": round(np.mean(ndcg_100) * 100, 2),
62
+ }
63
+
64
+ return results
65
+
66
+ # ==== Gradio Wrapper ====
67
+ def process_json(file):
68
+ try:
69
+ pred_data = json.load(open(file))
70
+ except Exception as e:
71
+ return f"Invalid JSON format: {str(e)}"
72
+
73
+ try:
74
+ secret_key = os.getenv("SECRET_KEY")
75
+ qrel_dict = load_and_decrypt_qrel(secret_key)
76
+ except Exception as e:
77
+ return str(e)
78
+
79
+ try:
80
+ metrics = evaluate(pred_data, qrel_dict)
81
+ return json.dumps(metrics, indent=2)
82
+ except Exception as e:
83
+ return f"Error during evaluation: {str(e)}"
84
+
85
+ # ==== Launch Gradio App ====
86
+ def main_gradio():
87
+ example_json = '''{
88
+ "mscoco": [
89
+ {"query_id": "1", "rank_list": "5, 2, 8"},
90
+ {"query_id": "2", "rank_list": "9, 1, 3"}
91
+ ],
92
+ "google_wit": [
93
+ {"query_id": "3", "rank_list": "11, 5, 22"}
94
+ ]
95
+ }'''
96
+ gr.Interface(
97
+ fn=process_json,
98
+ inputs=gr.File(label="Upload Prediction JSON"),
99
+ outputs=gr.Textbox(label="Evaluation Metrics"),
100
+ title="Mixed-Modality Retrieval Evaluation",
101
+ description="Upload a prediction JSON to evaluate Recall@1, NDCG@10, and NDCG@100 against encrypted qrels.\n\nExample input:\n" + example_json
102
+ ).launch(share=True)
103
+
104
+ if __name__ == "__main__":
105
+ main_gradio()