File size: 4,982 Bytes
cf62495 a2f941d cf62495 a2f941d cf62495 6eb346b cf62495 6eb346b cf62495 6eb346b cf62495 6eb346b cf62495 a2f941d cf62495 a2f941d cf62495 09df82b a2f941d 09df82b c52dc72 09df82b a2f941d 09df82b ec99498 cf62495 acc872d a2f941d acc872d a2f941d acc872d a2f941d acc872d cf62495 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
import gradio as gr
import json
import os
import numpy as np
from cryptography.fernet import Fernet
from collections import defaultdict
from sklearn.metrics import ndcg_score
def load_and_decrypt_qrel(secret_key, path="data/answer.enc"):
    """Decrypt the Fernet-encrypted answer file and index it as nested qrels.

    Args:
        secret_key: Fernet key as text (encoded to bytes before use).
        path: Location of the encrypted answer file; defaults to the path
            the app has always used.

    Returns:
        defaultdict mapping dataset -> query_id -> {corpus_id: score}, built
        from the decrypted JSON, which maps each dataset name to a list of
        ``{"query_id", "corpus_id", "score"}`` records.

    Raises:
        ValueError: on any failure (missing file, wrong key, malformed JSON or
            records). The underlying exception is chained as ``__cause__``.
    """
    try:
        with open(path, "rb") as enc_file:
            encrypted_data = enc_file.read()
        cipher = Fernet(secret_key.encode())
        decrypted_data = cipher.decrypt(encrypted_data).decode("utf-8")
        raw_data = json.loads(decrypted_data)
        # Convert to: dataset -> query_id -> {corpus_id: score}
        qrel_dict = defaultdict(lambda: defaultdict(dict))
        for dataset, records in raw_data.items():
            for item in records:
                qid, cid, score = item["query_id"], item["corpus_id"], item["score"]
                qrel_dict[dataset][qid][cid] = score
        return qrel_dict
    except Exception as e:
        # Some exceptions (e.g. Fernet's InvalidToken on a wrong key) have an
        # empty str(); fall back to the type name so the message is not blank.
        detail = str(e) or type(e).__name__
        raise ValueError(f"β Failed to decrypt answer file: {detail}") from e
def recall_at_k(corpus_top_100_list, relevant_ids, k=1):
    """Return 1 if any of the first ``k`` ranked ids is relevant, else 0.

    This is a per-query hit indicator; the mean over queries is what gets
    reported as Recall@k.

    Args:
        corpus_top_100_list: Ranked corpus ids, best first.
        relevant_ids: Container of ids that count as relevant for the query.
        k: Number of leading ranks to inspect.
    """
    for candidate in corpus_top_100_list[:k]:
        if candidate in relevant_ids:
            return 1
    return 0
def ndcg_at_k(corpus_top_100_list, rel_dict, k):
    """Compute NDCG@k of a submitted ranking against graded relevance labels.

    Gains are the qrel scores themselves (linear) and the discount is
    ``1 / log2(rank + 1)`` with 1-based ranks. This is the same definition
    used by ``sklearn.metrics.ndcg_score``.

    Args:
        corpus_top_100_list: Ranked corpus ids, best first. Repeated ids are
            collapsed to their first occurrence, so duplicates earn no extra gain.
        rel_dict: Mapping corpus_id -> relevance score for this query; ids not
            present count as 0. Assumes non-negative scores — TODO confirm.
        k: Ranking depth cut-off.

    Returns:
        float; 0.0 when the query has no positively-scored documents.
    """
    # Only documents the model actually retrieved may contribute to DCG.
    # Relevant ids the model missed appear only in the ideal DCG below. Ranking
    # them after the submission would credit them whenever the list is shorter
    # than k.
    ranked = list(dict.fromkeys(corpus_top_100_list))[:k]
    dcg = sum(rel_dict.get(doc, 0) / np.log2(rank + 2) for rank, doc in enumerate(ranked))
    ideal = sorted((s for s in rel_dict.values() if s > 0), reverse=True)[:k]
    idcg = sum(s / np.log2(rank + 2) for rank, s in enumerate(ideal))
    if idcg <= 0:
        return 0.0
    return float(dcg / idcg)
def evaluate(pred_data, qrel_dict):
results = {}
for dataset, queries in pred_data.items():
if dataset not in qrel_dict:
continue
recall_1, ndcg_10, ndcg_100 = [], [], []
for item in queries:
qid = item["query_id"]
corpus_top_100_list = item["corpus_top_100_list"].split(",")
corpus_top_100_list = [x.strip() for x in corpus_top_100_list if x.strip()]
rel_dict = qrel_dict[dataset].get(qid, {})
relevant_ids = [cid for cid, score in rel_dict.items() if score > 0]
recall_1.append(recall_at_k(corpus_top_100_list, relevant_ids, 1))
ndcg_10.append(ndcg_at_k(corpus_top_100_list, rel_dict, 10))
ndcg_100.append(ndcg_at_k(corpus_top_100_list, rel_dict, 100))
results[dataset] = {
"Recall@1": round(np.mean(recall_1) * 100, 2),
"NDCG@10": round(np.mean(ndcg_10) * 100, 2),
"NDCG@100": round(np.mean(ndcg_100) * 100, 2),
}
return results
def process_json(file):
    """Gradio callback: evaluate an uploaded submission and return a report.

    Args:
        file: The upload as delivered by ``gr.File``. This is either a path
            string or a file-like object exposing ``.name``; some Gradio
            versions pass a tempfile wrapper rather than a path.

    Returns:
        Pretty-printed JSON of the metrics on success, otherwise a
        human-readable error message. Failures are reported as strings for
        the output textbox rather than raised.
    """
    # Accept a plain path as well as a wrapper object carrying the path in .name.
    path = getattr(file, "name", file)
    try:
        # The context manager closes the upload promptly. Explicit UTF-8 avoids
        # depending on the host's locale encoding.
        with open(path, encoding="utf-8") as fh:
            pred_data = json.load(fh)
    except Exception as e:
        return f"β Invalid JSON format: {str(e)}"
    secret_key = os.getenv("SECRET_KEY")
    if not secret_key:
        return "β SECRET_KEY environment variable not set. Please configure it in your Hugging Face Space."
    try:
        qrel_dict = load_and_decrypt_qrel(secret_key)
    except Exception as e:
        return str(e)
    try:
        metrics = evaluate(pred_data, qrel_dict)
        return json.dumps(metrics, indent=2)
    except Exception as e:
        return f"β Error during evaluation: {str(e)}"
def main_gradio():
    """Build the MixBench evaluation interface and launch it.

    The page shows instructions plus an example submission. Uploaded files
    are passed to ``process_json`` and its return value is shown in a textbox.
    """
    # Example submission, rendered as preformatted HTML inside the description.
    example_json_html = "".join([
        '<pre><code>{<br>',
        ' "Google_WIT": [<br>',
        ' {"query_id": "1", "corpus_top_100_list": "5, 2, 8, ..."},<br>',
        ' {"query_id": "2", "corpus_top_100_list": "90, 13, 3, ..."}<br>',
        ' ],<br>',
        ' "MSCOCO": [<br>',
        ' {"query_id": "1", "corpus_top_100_list": "122, 35, 22, ..."}<br>',
        ' ],<br>',
        ' "OVEN": [<br>',
        ' {"query_id": "1", "corpus_top_100_list": "11, 15, 22, ..."}<br>',
        ' ],<br>',
        ' "VisualNews": [<br>',
        ' {"query_id": "1", "corpus_top_100_list": "101, 35, 77, ..."}<br>',
        ' ]<br>',
        '}</code></pre>',
    ])
    intro_html = (
        "Please upload your model's retrieval result on MixBench (in JSON format) to automatically evaluate its performance.<br><br>"
        "For each subset (e.g., <code>MSCOCO</code>, <code>Google_WIT</code>, <code>VisualNews</code>, <code>OVEN</code>), "
        "we compute:<br>"
        "- <strong>Recall@1</strong><br>"
        "- <strong>NDCG@10</strong><br>"
        "- <strong>NDCG@100</strong><br><br>"
        "Expected input JSON format:<br><br>"
    )
    footer_html = (
        "<br>To find valid query IDs, see the "
        "<a href='https://huggingface.co/datasets/mixed-modality-search/MixBench2025/viewer/Google_WIT/mixed_corpus' target='_blank'>MixBench2025 dataset viewer</a>."
    )
    demo = gr.Interface(
        fn=process_json,
        inputs=gr.File(label="Upload Retrieval Result (JSON)"),
        outputs=gr.Textbox(label="Evaluation Results"),
        title="π Automated Evaluation of MixBench",
        description=intro_html + example_json_html + footer_html,
    )
    demo.launch(share=True)
if __name__ == "__main__":
    # Launch the Gradio app only when this file is executed directly, not on import.
    main_gradio()
|