File size: 4,982 Bytes
cf62495
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a2f941d
cf62495
 
 
 
 
 
 
a2f941d
cf62495
6eb346b
 
cf62495
6eb346b
 
cf62495
 
 
 
 
 
 
 
 
 
 
 
 
6eb346b
 
cf62495
 
 
6eb346b
 
 
cf62495
 
 
 
 
 
 
 
 
 
 
 
 
a2f941d
 
 
 
 
cf62495
 
 
 
 
 
 
 
 
 
a2f941d
 
cf62495
09df82b
a2f941d
09df82b
 
 
 
 
c52dc72
 
 
 
 
 
 
09df82b
a2f941d
09df82b
ec99498
cf62495
acc872d
 
a2f941d
 
acc872d
 
 
a2f941d
acc872d
 
 
a2f941d
 
acc872d
 
 
cf62495
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import gradio as gr
import json
import os
import numpy as np
from cryptography.fernet import Fernet
from collections import defaultdict
from sklearn.metrics import ndcg_score

def load_and_decrypt_qrel(secret_key, path="data/answer.enc"):
    """Load, decrypt and regroup the ground-truth relevance judgments (qrels).

    The file at *path* is decrypted with ``Fernet(secret_key.encode())`` and
    parsed as JSON of the form ``{dataset: [record, ...]}``, where every record
    carries ``query_id``, ``corpus_id`` and ``score`` fields.

    Args:
        secret_key: Fernet key as a text string (encoded to bytes before use).
        path: Location of the encrypted answer file. Defaults to
            ``"data/answer.enc"``, the path this function always used.

    Returns:
        Nested mapping ``dataset -> query_id -> {corpus_id: score}`` (built from
        ``defaultdict`` objects).

    Raises:
        ValueError: if the file cannot be read, decrypted, parsed, or lacks
            the expected fields. The underlying exception is chained as
            ``__cause__`` and its type name is included in the message
            (Fernet's ``InvalidToken`` carries no text of its own, so the
            type name is the only hint that the key was wrong).
    """
    try:
        with open(path, "rb") as enc_file:
            encrypted_data = enc_file.read()
        cipher = Fernet(secret_key.encode())
        decrypted_data = cipher.decrypt(encrypted_data).decode("utf-8")
        raw_data = json.loads(decrypted_data)

        # Convert to: dataset -> query_id -> {corpus_id: score}
        qrel_dict = defaultdict(lambda: defaultdict(dict))
        for dataset, records in raw_data.items():
            for record in records:
                qid, cid, score = record["query_id"], record["corpus_id"], record["score"]
                qrel_dict[dataset][qid][cid] = score
        return qrel_dict
    except Exception as e:
        raise ValueError(
            f"❌ Failed to decrypt answer file: {type(e).__name__}: {e}"
        ) from e

def recall_at_k(corpus_top_100_list, relevant_ids, k=1):
    """Binary hit indicator over the first *k* retrieved ids.

    Args:
        corpus_top_100_list: Retrieved corpus ids, best first.
        relevant_ids: Collection of ids judged relevant for the query.
        k: Number of leading positions to inspect.

    Returns:
        ``1`` if at least one of the first *k* ids is in *relevant_ids*,
        otherwise ``0``.
    """
    for candidate in corpus_top_100_list[:k]:
        if candidate in relevant_ids:
            return 1
    return 0

def ndcg_at_k(corpus_top_100_list, rel_dict, k):
    """Return NDCG@k of a submitted ranking against graded relevance labels.

    Uses linear gains (the raw ``rel_dict`` value) and a ``1 / log2(rank + 1)``
    position discount, the same formula as ``sklearn.metrics.ndcg_score``.

    Only items the submission actually ranked earn gain. Relevant ids that
    were never retrieved count toward the ideal DCG but add nothing to the
    DCG. (Appending them to the ranking would hand a short or empty list
    credit for documents it never returned.)

    Args:
        corpus_top_100_list: Submitted corpus ids, best first. A repeated id
            only counts at its first position.
        rel_dict: ``{corpus_id: relevance}`` for the query; ids absent from it
            have relevance 0.
        k: Rank cutoff (negative values are treated as 0).

    Returns:
        A float in [0, 1] for non-negative labels; ``0.0`` when the query has
        no positively-labelled ids.
    """
    k = max(int(k), 0)
    ranked = list(dict.fromkeys(corpus_top_100_list))[:k]
    gains = np.array([rel_dict.get(cid, 0) for cid in ranked], dtype=float)
    # Rank r (1-based) is discounted by log2(r + 1); arange starts at 2 for r=1.
    dcg = float(np.sum(gains / np.log2(np.arange(2, gains.size + 2))))

    ideal = np.sort(np.array([v for v in rel_dict.values() if v > 0], dtype=float))[::-1][:k]
    idcg = float(np.sum(ideal / np.log2(np.arange(2, ideal.size + 2))))
    return dcg / idcg if idcg > 0 else 0.0

def evaluate(pred_data, qrel_dict):
    """Score a submission against the ground-truth qrels, per dataset.

    Args:
        pred_data: Mapping ``dataset -> list of {"query_id", "corpus_top_100_list"}``.
            ``corpus_top_100_list`` may be a comma-separated string of corpus
            ids or a list of ids. Each id is converted to ``str`` and stripped;
            empty ids are dropped.
        qrel_dict: Mapping ``dataset -> query_id -> {corpus_id: score}`` (the
            structure ``load_and_decrypt_qrel`` returns).

    Returns:
        ``{dataset: {"Recall@1": .., "NDCG@10": .., "NDCG@100": ..}}``. Each
        value is the mean over the submitted queries, scaled to 0-100 and
        rounded to 2 decimals. Datasets missing from *qrel_dict*, or with no
        submitted queries, are omitted (an empty mean would be NaN, which
        ``json.dumps`` renders as invalid JSON).
    """
    results = {}
    for dataset, queries in pred_data.items():
        if dataset not in qrel_dict or not queries:
            continue

        dataset_qrels = qrel_dict[dataset]
        recall_1, ndcg_10, ndcg_100 = [], [], []
        # NOTE(review): averages run over the *submitted* queries only. Qrel
        # queries a submission omits are not penalised, and a repeated
        # query_id is counted twice — confirm this is the intended protocol.
        for item in queries:
            qid = item["query_id"]
            raw_ranking = item["corpus_top_100_list"]
            parts = raw_ranking.split(",") if isinstance(raw_ranking, str) else raw_ranking
            stripped = (str(x).strip() for x in parts)
            corpus_top_100_list = [x for x in stripped if x]

            rel_dict = dataset_qrels.get(qid, {})
            relevant_ids = {cid for cid, score in rel_dict.items() if score > 0}

            recall_1.append(recall_at_k(corpus_top_100_list, relevant_ids, 1))
            ndcg_10.append(ndcg_at_k(corpus_top_100_list, rel_dict, 10))
            ndcg_100.append(ndcg_at_k(corpus_top_100_list, rel_dict, 100))

        results[dataset] = {
            "Recall@1": round(np.mean(recall_1) * 100, 2),
            "NDCG@10": round(np.mean(ndcg_10) * 100, 2),
            "NDCG@100": round(np.mean(ndcg_100) * 100, 2),
        }

    return results

def process_json(file):
    """Gradio callback: evaluate an uploaded retrieval-result JSON file.

    Args:
        file: Value supplied by the ``gr.File`` input. It may be a filesystem
            path (``str``/``os.PathLike``) or a file wrapper exposing its path
            as ``.name`` (older Gradio versions). It is ``None`` when nothing
            was uploaded.

    Returns:
        Text for the output textbox: the per-dataset metrics as indented JSON
        on success, otherwise a message starting with "❌".
    """
    if file is None:
        return "❌ No file uploaded. Please upload a JSON file."
    # os.PathLike is checked first: pathlib.Path also has a ``.name`` (basename).
    path = file if isinstance(file, (str, os.PathLike)) else getattr(file, "name", file)

    try:
        with open(path, encoding="utf-8") as fh:
            pred_data = json.load(fh)
    except Exception as e:
        return f"❌ Invalid JSON format: {str(e)}"
    if not isinstance(pred_data, dict):
        return "❌ Invalid JSON format: the top level must be an object mapping dataset names to query lists."

    secret_key = os.getenv("SECRET_KEY")
    if not secret_key:
        return "❌ SECRET_KEY environment variable not set. Please configure it in your Hugging Face Space."

    try:
        qrel_dict = load_and_decrypt_qrel(secret_key)
    except Exception as e:
        return str(e)

    try:
        metrics = evaluate(pred_data, qrel_dict)
        return json.dumps(metrics, indent=2)
    except Exception as e:
        return f"❌ Error during evaluation: {str(e)}"

def main_gradio():
    """Build and launch the Gradio UI for MixBench evaluation.

    The interface takes one uploaded JSON file, passes it to ``process_json``
    and shows the returned text in a textbox. The description documents the
    expected input format with an inline HTML example. The app is launched
    with ``share=True``.
    """
    example_json_html = (
        '<pre><code>{<br>'
        '&nbsp;&nbsp;"Google_WIT": [<br>'
        '&nbsp;&nbsp;&nbsp;&nbsp;{"query_id": "1", "corpus_top_100_list": "5, 2, 8, ..."},<br>'
        '&nbsp;&nbsp;&nbsp;&nbsp;{"query_id": "2", "corpus_top_100_list": "90, 13, 3, ..."}<br>'
        '&nbsp;&nbsp;],<br>'
        '&nbsp;&nbsp;"MSCOCO": [<br>'
        '&nbsp;&nbsp;&nbsp;&nbsp;{"query_id": "1", "corpus_top_100_list": "122, 35, 22, ..."}<br>'
        '&nbsp;&nbsp;],<br>'
        '&nbsp;&nbsp;"OVEN": [<br>'
        '&nbsp;&nbsp;&nbsp;&nbsp;{"query_id": "1", "corpus_top_100_list": "11, 15, 22, ..."}<br>'
        '&nbsp;&nbsp;],<br>'
        '&nbsp;&nbsp;"VisualNews": [<br>'
        '&nbsp;&nbsp;&nbsp;&nbsp;{"query_id": "1", "corpus_top_100_list": "101, 35, 77, ..."}<br>'
        '&nbsp;&nbsp;]<br>'
        '}</code></pre>'
    )

    gr.Interface(
        fn=process_json,
        inputs=gr.File(label="Upload Retrieval Result (JSON)"),
        outputs=gr.Textbox(label="Evaluation Results"),
        title="🔍 Automated Evaluation of MixBench",
        description=(
            "Please upload your model's retrieval result on MixBench (in JSON format) to automatically evaluate its performance.<br><br>"
            "For each subset (e.g., <code>MSCOCO</code>, <code>Google_WIT</code>, <code>VisualNews</code>, <code>OVEN</code>), "
            "we compute:<br>"
            "- <strong>Recall@1</strong><br>"
            "- <strong>NDCG@10</strong><br>"
            "- <strong>NDCG@100</strong><br><br>"
            "Expected input JSON format:<br><br>" + example_json_html +
            "<br>To find valid query IDs, see the "
            "<a href='https://huggingface.co/datasets/mixed-modality-search/MixBench2025/viewer/Google_WIT/mixed_corpus' target='_blank'>MixBench2025 dataset viewer</a>."
        )
    ).launch(share=True)

if __name__ == "__main__":
    # Start the web UI only when this file is executed as a script.
    main_gradio()