File size: 5,023 Bytes
cf62495
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6eb346b
 
cf62495
6eb346b
 
cf62495
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6eb346b
 
cf62495
 
 
6eb346b
 
 
cf62495
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ec99498
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cf62495
 
6eb346b
 
 
cf62495
6eb346b
 
 
ec99498
6eb346b
 
ec99498
6eb346b
 
cf62495
 
ec99498
 
 
cf62495
 
6eb346b
 
 
ec99498
 
cf62495
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import gradio as gr
import json
import os
import numpy as np
from cryptography.fernet import Fernet
from collections import defaultdict
from sklearn.metrics import ndcg_score

def load_and_decrypt_qrel(secret_key, path="data/answer.enc"):
    """Decrypt the Fernet-encrypted qrels file and index it by dataset and query.

    The decrypted payload is parsed as JSON: an object mapping each dataset
    name to a list of records carrying ``query_id``, ``corpus_id`` and
    ``score`` keys.

    Args:
        secret_key: Fernet key (``str`` or ``bytes``) used to decrypt the file.
        path: Location of the encrypted qrels file.

    Returns:
        Nested ``defaultdict``: ``dataset -> query_id -> {corpus_id: score}``.
        Corpus ids are stored as ``str`` so they compare equal to the ids
        parsed from a submission's comma-separated ranking.

    Raises:
        ValueError: if the key is missing, or the file cannot be read,
            decrypted, or parsed.
    """
    if not secret_key:
        raise ValueError("Failed to decrypt answer file: secret key is not set")
    try:
        key = secret_key if isinstance(secret_key, bytes) else secret_key.encode()
        with open(path, "rb") as enc_file:
            encrypted_data = enc_file.read()
        cipher = Fernet(key)
        decrypted_data = cipher.decrypt(encrypted_data).decode("utf-8")
        raw_data = json.loads(decrypted_data)

        # qrel_dict: dataset -> query_id -> {corpus_id: score}
        qrel_dict = defaultdict(lambda: defaultdict(dict))
        for dataset, records in raw_data.items():
            for item in records:
                qid, cid, score = item["query_id"], item["corpus_id"], item["score"]
                # Submissions are split into string ids, so key on str(cid).
                qrel_dict[dataset][qid][str(cid)] = score
    except Exception as e:
        # Some failures (e.g. a wrong key) raise exceptions with an empty
        # message, so include the exception type to keep the error actionable.
        raise ValueError(
            f"Failed to decrypt answer file: {type(e).__name__}: {e}"
        ) from e
    return qrel_dict

def recall_at_k(corpus_top_100_list, relevant_ids, k=1):
    """Return 1 when at least one of the first ``k`` ranked ids is relevant.

    Args:
        corpus_top_100_list: Ranked corpus ids, best first.
        relevant_ids: Container of ids judged relevant (anything supporting ``in``).
        k: How many leading ranks to inspect.

    Returns:
        ``1`` on a hit within the top ``k``, otherwise ``0``.
    """
    for candidate in corpus_top_100_list[:k]:
        if candidate in relevant_ids:
            return 1
    return 0

def ndcg_at_k(corpus_top_100_list, rel_dict, k):
    """Return NDCG@k of a ranked list of corpus ids against graded relevance.

    Uses linear gain and a ``log2(rank + 1)`` discount (1-based rank), the
    same formulation as ``sklearn.metrics.ndcg_score``. A duplicated id only
    counts at its first position.

    Args:
        corpus_top_100_list: Predicted corpus ids, best first.
        rel_dict: ``{corpus_id: relevance}`` for the query; ids absent from it
            have relevance 0.
        k: Rank cutoff.

    Returns:
        NDCG@k as a float (in ``[0, 1]`` for non-negative relevance), or
        ``0.0`` when the ideal DCG is zero (no positive relevance, or ``k`` is 0).
    """
    # Only documents the system actually retrieved may earn gain. Padding the
    # ranking with unretrieved relevant ids would credit them whenever the
    # ranking is shorter than k.
    ranked = list(dict.fromkeys(corpus_top_100_list))[:k]
    dcg = sum(
        rel_dict.get(doc, 0) / np.log2(rank + 2) for rank, doc in enumerate(ranked)
    )

    # Ideal ranking: every judged document ordered by decreasing relevance.
    ideal = sorted(rel_dict.values(), reverse=True)[:k]
    idcg = sum(rel / np.log2(rank + 2) for rank, rel in enumerate(ideal))

    if idcg <= 0:
        return 0.0
    return float(dcg / idcg)

def evaluate(pred_data, qrel_dict):
    """Score a submission against the qrels, dataset by dataset.

    Args:
        pred_data: Mapping ``dataset -> list of items``; each item has a
            ``query_id`` and a ``corpus_top_100_list`` that is either a
            comma-separated string of corpus ids or a list of ids.
        qrel_dict: Mapping ``dataset -> query_id -> {corpus_id: score}``.

    Returns:
        ``{dataset: {"Recall@1": ..., "NDCG@10": ..., "NDCG@100": ...}}``, each
        metric averaged over the submitted queries and expressed as a
        percentage rounded to 2 decimals. Datasets missing from the qrels, or
        submitted with no queries, are omitted.
    """
    results = {}
    for dataset, queries in pred_data.items():
        # Skip datasets without ground truth, and empty ones whose mean would
        # be NaN (which json.dumps emits as non-standard "NaN").
        if dataset not in qrel_dict or not queries:
            continue
        dataset_qrels = qrel_dict[dataset]

        recall_1, ndcg_10, ndcg_100 = [], [], []

        for item in queries:
            qid = item["query_id"]
            raw_ranking = item["corpus_top_100_list"]
            if isinstance(raw_ranking, str):
                raw_ranking = raw_ranking.split(",")
            # Normalise to stripped, non-empty string ids.
            corpus_top_100_list = [
                cid for cid in (str(x).strip() for x in raw_ranking) if cid
            ]
            # A numeric query_id should still match string-keyed qrels.
            rel_dict = dataset_qrels.get(qid) or dataset_qrels.get(str(qid), {})
            relevant_ids = {cid for cid, score in rel_dict.items() if score > 0}

            recall_1.append(recall_at_k(corpus_top_100_list, relevant_ids, 1))
            ndcg_10.append(ndcg_at_k(corpus_top_100_list, rel_dict, 10))
            ndcg_100.append(ndcg_at_k(corpus_top_100_list, rel_dict, 100))

        # Plain floats keep the result JSON-friendly.
        results[dataset] = {
            "Recall@1": round(float(np.mean(recall_1)) * 100, 2),
            "NDCG@10": round(float(np.mean(ndcg_10)) * 100, 2),
            "NDCG@100": round(float(np.mean(ndcg_100)) * 100, 2),
        }

    return results

# ==== Gradio Wrapper ====
def process_json(file):
    """Gradio callback: evaluate an uploaded prediction file and report metrics.

    Args:
        file: The upload from ``gr.File``. This is a filesystem path, or an
            object exposing the path as ``.name``, or ``None`` when nothing
            was uploaded.

    Returns:
        A pretty-printed JSON string of per-dataset metrics on success, or a
        human-readable error message. Errors are returned rather than raised
        so they are shown in the results textbox.
    """
    if file is None:
        return "No file uploaded."
    path = getattr(file, "name", file)

    try:
        # utf-8-sig also accepts files saved with a leading BOM.
        with open(path, encoding="utf-8-sig") as fh:
            pred_data = json.load(fh)
    except Exception as e:
        return f"Invalid JSON format: {str(e)}"
    if not isinstance(pred_data, dict):
        return (
            "Invalid submission: top-level JSON value must be an object "
            "mapping dataset names to query lists."
        )

    secret_key = os.getenv("SECRET_KEY")
    if not secret_key:
        return "SECRET_KEY environment variable is not set."
    try:
        qrel_dict = load_and_decrypt_qrel(secret_key)
    except Exception as e:
        return str(e)

    try:
        metrics = evaluate(pred_data, qrel_dict)
        return json.dumps(metrics, indent=2)
    except Exception as e:
        return f"Error during evaluation: {str(e)}"

# ==== Launch Gradio App ====
# def main_gradio():
#     example_json = '''{
#   "Google_WIT": [
#     {"query_id": "1", "corpus_top_100_list": "5, 2, 8, ..."},
#     {"query_id": "2", "corpus_top_100_list": "90, 13, 3, ..."}
#   ],
#   "MSCOCO": [
#     {"query_id": "3", "corpus_top_100_list": "122, 35, 22, ..."},
#     {"query_id": "2", "corpus_top_100_list": "90, 19, 3, ..."}
#   ]
#   "OVEN": [
#     {"query_id": "3", "corpus_top_100_list": "11, 15, 22, ..."}
#   ]
#   "VisualNews": [
#     {"query_id": "3", "corpus_top_100_list": "101, 35, 22, ..."}
#   ]
# }'''
#     gr.Interface(
#         fn=process_json,
#         inputs=gr.File(label="Upload Retrieval Result (JSON)"),
#         outputs=gr.Textbox(label="Results"),
#         title="Automated Evaluation of MixBench",
#         description="Upload a prediction JSON to evaluate Recall@1, NDCG@10, and NDCG@100 against encrypted qrels.\n\nExample input:\n" + example_json
#     ).launch(share=True)

def main_gradio(share=True):
    """Build and launch the MixBench evaluation UI.

    Args:
        share: Forwarded to ``launch``. ``True`` (the previously hard-coded
            value) requests a Gradio share link; pass ``False`` to serve
            locally only.
    """
    example_json = '''{
  "Google_WIT": [
    {"query_id": "1", "corpus_top_100_list": "5, 2, 8, ..."},
    {"query_id": "2", "corpus_top_100_list": "90, 13, 3, ..."}
  ],
  "MSCOCO": [
    {"query_id": "3", "corpus_top_100_list": "122, 35, 22, ..."},
    {"query_id": "2", "corpus_top_100_list": "90, 19, 3, ..."}
  ],
  "OVEN": [
    {"query_id": "3", "corpus_top_100_list": "11, 15, 22, ..."}
  ],
  "VisualNews": [
    {"query_id": "3", "corpus_top_100_list": "101, 35, 22, ..."}
  ]
}'''
    # Wrap the example in <pre> so its line breaks and indentation are kept
    # when the HTML description is rendered (no <br> conversion is needed).
    formatted_example = f"<pre>{example_json}</pre>"

    gr.Interface(
        fn=process_json,
        inputs=gr.File(label="Upload Retrieval Result (JSON)"),
        outputs=gr.Textbox(label="Results"),
        title="Automated Evaluation of MixBench",
        description="Upload a prediction JSON to evaluate Recall@1, NDCG@10, and NDCG@100 against encrypted qrels.<br><br>"
                    "Example input format:<br>" + formatted_example
    ).launch(share=share)

if __name__ == "__main__":
    # Launch the Gradio UI only when run as a script, not when imported.
    main_gradio()