import gradio as gr
import json
import os
import numpy as np
from cryptography.fernet import Fernet
from collections import defaultdict
from sklearn.metrics import ndcg_score
def load_and_decrypt_qrel(secret_key, path="data/answer.enc"):
    """Load the Fernet-encrypted ground-truth qrels and index them by dataset and query.

    The decrypted payload is expected to be JSON of the form
    ``{dataset: [{"query_id": ..., "corpus_id": ..., "score": ...}, ...]}``.

    Args:
        secret_key: Fernet key as ``str`` or ``bytes`` (the app reads it from the
            ``SECRET_KEY`` environment variable).
        path: Location of the encrypted answer file.

    Returns:
        Nested ``defaultdict``: ``dataset -> query_id -> {corpus_id: score}``.
        Query and corpus ids are normalised to ``str`` so they compare equal to
        the string ids parsed from prediction files.

    Raises:
        ValueError: If no key is given, or if reading, decrypting or parsing the
            answer file fails for any reason.
    """
    if not secret_key:
        raise ValueError("Failed to decrypt answer file: SECRET_KEY is not set.")
    key = secret_key.encode() if isinstance(secret_key, str) else secret_key
    try:
        with open(path, "rb") as enc_file:
            encrypted_data = enc_file.read()
        cipher = Fernet(key)
        raw_data = json.loads(cipher.decrypt(encrypted_data).decode("utf-8"))
        qrel_dict = defaultdict(lambda: defaultdict(dict))
        for dataset, records in raw_data.items():
            for item in records:
                # Predicted corpus ids are always strings (split from a CSV
                # string), so store qrel ids as strings too.
                qid, cid = str(item["query_id"]), str(item["corpus_id"])
                qrel_dict[dataset][qid][cid] = item["score"]
        return qrel_dict
    except Exception as e:
        # Some errors (e.g. Fernet's InvalidToken) can carry an empty message,
        # so include the exception type to keep the report actionable.
        raise ValueError(
            f"Failed to decrypt answer file: {type(e).__name__}: {e}"
        ) from e
def recall_at_k(corpus_top_100_list, relevant_ids, k=1):
    """Return 1 if any relevant id appears in the first ``k`` ranked ids, else 0.

    This is a per-query hit indicator (success@k); averaging it over queries
    gives the reported Recall@k.

    Args:
        corpus_top_100_list: Ranked corpus ids, best first.
        relevant_ids: Container of ids judged relevant for the query.
        k: Cutoff rank (defaults to 1).
    """
    top_k = corpus_top_100_list[:k]
    for candidate in top_k:
        if candidate in relevant_ids:
            return 1
    return 0
def ndcg_at_k(corpus_top_100_list, rel_dict, k):
    """Compute NDCG@k for one query's ranked list against graded relevance labels.

    Uses linear gain (the raw relevance score) and a ``1 / log2(rank + 1)``
    discount, the same formula as ``sklearn.metrics.ndcg_score``, but computed
    directly so that short or empty rankings do not raise.

    Args:
        corpus_top_100_list: Submitted corpus ids in rank order (best first).
            Repeated ids count only at their first occurrence.
        rel_dict: ``{corpus_id: relevance score}`` for this query; ids not present
            count as relevance 0. Negative scores are treated as 0.
        k: Cutoff rank; ``None`` means no cutoff.

    Returns:
        NDCG@k as a float in ``[0, 1]``; ``0.0`` when the query has no positively
        judged document or when ``k <= 0``.
    """
    if k is not None and k <= 0:
        return 0.0
    # Only documents the system actually returned may earn gain. Unretrieved
    # relevant documents contribute to the ideal DCG only; placing them after
    # the ranking would credit them whenever the submitted list is shorter than k.
    ranked = list(dict.fromkeys(corpus_top_100_list))[:k]
    gains = [max(rel_dict.get(doc, 0), 0) for doc in ranked]
    ideal = sorted((max(score, 0) for score in rel_dict.values()), reverse=True)[:k]

    def _dcg(values):
        # rank is 0-based, so the discount denominator is log2(rank + 2).
        return float(sum(g / np.log2(rank + 2) for rank, g in enumerate(values)))

    idcg = _dcg(ideal)
    if idcg == 0:
        return 0.0
    return _dcg(gains) / idcg
def _parse_ranking(raw):
    """Turn a submitted ranking (comma-separated string or list) into a list of id strings."""
    parts = raw.split(",") if isinstance(raw, str) else raw
    return [str(x).strip() for x in parts if str(x).strip()]


def _mean_percent(values):
    """Mean of ``values`` as a percentage rounded to 2 decimals; 0.0 for no values."""
    # np.mean([]) is NaN (and json.dumps would emit invalid `NaN`), so guard it.
    if not values:
        return 0.0
    return round(float(np.mean(values)) * 100, 2)


def evaluate(pred_data, qrel_dict):
    """Score a prediction submission against the ground-truth qrels.

    Args:
        pred_data: Mapping ``dataset -> list of {"query_id", "corpus_top_100_list"}``.
            ``corpus_top_100_list`` is a comma-separated string of corpus ids in
            rank order (a list of ids is also accepted).
        qrel_dict: Mapping ``dataset -> query_id -> {corpus_id: score}``.

    Returns:
        ``{dataset: {"Recall@1", "NDCG@10", "NDCG@100"}}``, each metric averaged
        over the submitted queries of that dataset and expressed as a percentage
        rounded to two decimals. Datasets absent from ``qrel_dict`` are skipped.
        NOTE(review): queries present in the qrels but missing from the
        submission are not counted at all, rather than scored as 0.
    """
    results = {}
    for dataset, queries in pred_data.items():
        if dataset not in qrel_dict:
            continue
        # Compare ids as strings on both sides: ranked ids parsed from the
        # submission are always str, while qrel or query ids may be ints.
        dataset_qrels = {
            str(qid): {str(cid): score for cid, score in rels.items()}
            for qid, rels in qrel_dict[dataset].items()
        }
        recall_1, ndcg_10, ndcg_100 = [], [], []
        for item in queries:
            ranking = _parse_ranking(item["corpus_top_100_list"])
            rel_dict = dataset_qrels.get(str(item["query_id"]), {})
            relevant_ids = {cid for cid, score in rel_dict.items() if score > 0}
            recall_1.append(recall_at_k(ranking, relevant_ids, 1))
            ndcg_10.append(ndcg_at_k(ranking, rel_dict, 10))
            ndcg_100.append(ndcg_at_k(ranking, rel_dict, 100))
        results[dataset] = {
            "Recall@1": _mean_percent(recall_1),
            "NDCG@10": _mean_percent(ndcg_10),
            "NDCG@100": _mean_percent(ndcg_100),
        }
    return results
# ==== Gradio Wrapper ====
def process_json(file):
    """Gradio callback: evaluate an uploaded prediction JSON file.

    Args:
        file: The value produced by ``gr.File`` — a filesystem path in recent
            Gradio versions, or a tempfile-like object exposing ``.name`` in
            older ones; ``None`` when nothing was uploaded.

    Returns:
        A string for the results textbox: the metrics as indented JSON on
        success, otherwise a human-readable error message.
    """
    if file is None:
        return "No file uploaded. Please upload a prediction JSON file."
    path = file if isinstance(file, (str, os.PathLike)) else getattr(file, "name", file)
    try:
        # `with` closes the handle; JSONDecodeError/UnicodeDecodeError are ValueErrors.
        with open(path, encoding="utf-8") as fh:
            pred_data = json.load(fh)
    except (OSError, ValueError) as e:
        return f"Invalid JSON format: {str(e)}"
    if not isinstance(pred_data, dict):
        return (
            "Invalid prediction format: the top-level JSON value must be an object "
            "mapping dataset names to lists of predictions."
        )
    secret_key = os.getenv("SECRET_KEY")
    if not secret_key:
        return "Server misconfiguration: SECRET_KEY environment variable is not set."
    try:
        qrel_dict = load_and_decrypt_qrel(secret_key)
    except Exception as e:
        return str(e)
    try:
        metrics = evaluate(pred_data, qrel_dict)
        return json.dumps(metrics, indent=2)
    except Exception as e:
        return f"Error during evaluation: {str(e)}"
# ==== Launch Gradio App ====
# def main_gradio():
# example_json = '''{
# "Google_WIT": [
# {"query_id": "1", "corpus_top_100_list": "5, 2, 8, ..."},
# {"query_id": "2", "corpus_top_100_list": "90, 13, 3, ..."}
# ],
# "MSCOCO": [
# {"query_id": "3", "corpus_top_100_list": "122, 35, 22, ..."},
# {"query_id": "2", "corpus_top_100_list": "90, 19, 3, ..."}
# ]
# "OVEN": [
# {"query_id": "3", "corpus_top_100_list": "11, 15, 22, ..."}
# ]
# "VisualNews": [
# {"query_id": "3", "corpus_top_100_list": "101, 35, 22, ..."}
# ]
# }'''
# gr.Interface(
# fn=process_json,
# inputs=gr.File(label="Upload Retrieval Result (JSON)"),
# outputs=gr.Textbox(label="Results"),
# title="Automated Evaluation of MixBench",
# description="Upload a prediction JSON to evaluate Recall@1, NDCG@10, and NDCG@100 against encrypted qrels.\n\nExample input:\n" + example_json
# ).launch(share=True)
def main_gradio():
example_json = '''{
"Google_WIT": [
{"query_id": "1", "corpus_top_100_list": "5, 2, 8, ..."},
{"query_id": "2", "corpus_top_100_list": "90, 13, 3, ..."}
],
"MSCOCO": [
{"query_id": "3", "corpus_top_100_list": "122, 35, 22, ..."},
{"query_id": "2", "corpus_top_100_list": "90, 19, 3, ..."}
],
"OVEN": [
{"query_id": "3", "corpus_top_100_list": "11, 15, 22, ..."}
],
"VisualNews": [
{"query_id": "3", "corpus_top_100_list": "101, 35, 22, ..."}
]
}'''
# Convert \n to
and wrap example in
to preserve indentation formatted_example = f"{example_json}" gr.Interface( fn=process_json, inputs=gr.File(label="Upload Retrieval Result (JSON)"), outputs=gr.Textbox(label="Results"), title="Automated Evaluation of MixBench", description="Upload a prediction JSON to evaluate Recall@1, NDCG@10, and NDCG@100 against encrypted qrels.
" "Example input format:
" + formatted_example ).launch(share=True) if __name__ == "__main__": main_gradio()