import gradio as gr
import json
import os
import numpy as np
from cryptography.fernet import Fernet
from collections import defaultdict
from sklearn.metrics import ndcg_score


def load_and_decrypt_qrel(secret_key):
    """Decrypt data/answer.enc and build {dataset: {query_id: {corpus_id: score}}} qrels."""
    try:
        with open("data/answer.enc", "rb") as enc_file:
            encrypted_data = enc_file.read()
        cipher = Fernet(secret_key.encode())
        decrypted_data = cipher.decrypt(encrypted_data).decode("utf-8")
        raw_data = json.loads(decrypted_data)

        # Reshape the flat records into dataset -> query_id -> corpus_id -> relevance score.
        qrel_dict = defaultdict(lambda: defaultdict(dict))
        for dataset, records in raw_data.items():
            for item in records:
                qid, cid, score = item["query_id"], item["corpus_id"], item["score"]
                qrel_dict[dataset][qid][cid] = score
        return qrel_dict
    except Exception as e:
        raise ValueError(f"Failed to decrypt answer file: {str(e)}")


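# Illustrative sketch only (not called by the app): one way the encrypted answer file
# could be produced with Fernet. The input/output paths below are assumptions; the
# plaintext JSON is expected to follow the layout load_and_decrypt_qrel parses, i.e.
# {dataset: [{"query_id": ..., "corpus_id": ..., "score": ...}, ...]}.
def _encrypt_qrel_file(plain_json_path="data/answer.json", out_path="data/answer.enc"):
    key = Fernet.generate_key()  # this value (as a string) is what SECRET_KEY should hold
    cipher = Fernet(key)
    with open(plain_json_path, "rb") as f:
        encrypted = cipher.encrypt(f.read())
    with open(out_path, "wb") as f:
        f.write(encrypted)
    return key.decode()

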
def recall_at_k(corpus_top_100_list, relevant_ids, k=1):
    # 1 if any of the top-k retrieved ids is relevant, else 0.
    return int(any(item in relevant_ids for item in corpus_top_100_list[:k]))


def ndcg_at_k(corpus_top_100_list, rel_dict, k):
    # Union of retrieved ids and judged ids, preserving the retrieval order.
    all_items = list(dict.fromkeys(corpus_top_100_list + list(rel_dict.keys())))

    # Graded relevance for each id (0 for unjudged ids).
    y_true = [rel_dict.get(item, 0) for item in all_items]

    # Monotonically decreasing scores encode the predicted ranking.
    y_score = [len(all_items) - i for i in range(len(all_items))]

    return ndcg_score([y_true], [y_score], k=k)


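# Toy example (illustrative values only): for a ranking ["5", "2", "8"] with qrels
# {"2": 1}, recall_at_k(..., ["2"], 1) is 0 because the top-1 hit "5" is not relevant,
# while ndcg_at_k(..., {"2": 1}, 10) is about 0.63 because "2" appears at rank 2.

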
def evaluate(pred_data, qrel_dict):
    results = {}
    for dataset, queries in pred_data.items():
        # Skip datasets that have no ground-truth qrels.
        if dataset not in qrel_dict:
            continue

        recall_1, ndcg_10, ndcg_100 = [], [], []

        for item in queries:
            qid = item["query_id"]
            # Parse the comma-separated ranking and drop empty entries.
            corpus_top_100_list = item["corpus_top_100_list"].split(",")
            corpus_top_100_list = [x.strip() for x in corpus_top_100_list if x.strip()]
            rel_dict = qrel_dict[dataset].get(qid, {})
            relevant_ids = [cid for cid, score in rel_dict.items() if score > 0]

            recall_1.append(recall_at_k(corpus_top_100_list, relevant_ids, 1))
            ndcg_10.append(ndcg_at_k(corpus_top_100_list, rel_dict, 10))
            ndcg_100.append(ndcg_at_k(corpus_top_100_list, rel_dict, 100))

        results[dataset] = {
            "Recall@1": round(np.mean(recall_1) * 100, 2),
            "NDCG@10": round(np.mean(ndcg_10) * 100, 2),
            "NDCG@100": round(np.mean(ndcg_100) * 100, 2),
        }

    return results


def process_json(file):
    try:
        with open(file, "r") as f:
            pred_data = json.load(f)
    except Exception as e:
        return f"Invalid JSON format: {str(e)}"

    secret_key = os.getenv("SECRET_KEY")
    if not secret_key:
        return "SECRET_KEY environment variable is not set."

    try:
        qrel_dict = load_and_decrypt_qrel(secret_key)
    except Exception as e:
        return str(e)

    try:
        metrics = evaluate(pred_data, qrel_dict)
        return json.dumps(metrics, indent=2)
    except Exception as e:
        return f"Error during evaluation: {str(e)}"


def main_gradio():
    example_json = '''{
    "Google_WIT": [
        {"query_id": "1", "corpus_top_100_list": "5, 2, 8, ..."},
        {"query_id": "2", "corpus_top_100_list": "90, 13, 3, ..."}
    ],
    "MSCOCO": [
        {"query_id": "3", "corpus_top_100_list": "122, 35, 22, ..."},
        {"query_id": "2", "corpus_top_100_list": "90, 19, 3, ..."}
    ],
    "OVEN": [
        {"query_id": "3", "corpus_top_100_list": "11, 15, 22, ..."}
    ],
    "VisualNews": [
        {"query_id": "3", "corpus_top_100_list": "101, 35, 22, ..."}
    ]
}'''

    formatted_example = f"<pre>{example_json}</pre>"

    gr.Interface(
        fn=process_json,
        inputs=gr.File(label="Upload Retrieval Result (JSON)"),
        outputs=gr.Textbox(label="Results"),
        title="Automated Evaluation of MixBench",
        description="Upload a prediction JSON to evaluate Recall@1, NDCG@10, and NDCG@100 against encrypted qrels.<br><br>"
                    "Example input format:<br>" + formatted_example,
    ).launch(share=True)


if __name__ == "__main__":
    main_gradio()
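
# Typical usage (assumed workflow): set the SECRET_KEY environment variable to the Fernet
# key used to encrypt data/answer.enc, run this script, and upload a prediction JSON
# shaped like example_json above.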