mixed-modality-search's picture
update
ec99498
raw
history blame
5.02 kB
import gradio as gr
import json
import os
import numpy as np
from cryptography.fernet import Fernet
from collections import defaultdict
from sklearn.metrics import ndcg_score
def load_and_decrypt_qrel(secret_key):
    """Load the encrypted relevance judgments and index them for lookup.

    Reads ``data/answer.enc``, decrypts it with Fernet using *secret_key*,
    parses the plaintext as JSON and regroups the records.

    Args:
        secret_key: Fernet key as a string; it is encoded to bytes before use.

    Returns:
        Nested mapping ``dataset -> query_id -> {corpus_id: score}``. The
        outer two levels are ``defaultdict`` instances, so indexing a missing
        dataset or query creates an empty entry; use ``in`` / ``.get`` for
        read-only probing.

    Raises:
        ValueError: If no key is supplied, or if reading, decrypting or
            parsing the answer file fails. The underlying error is chained
            as ``__cause__``.
    """
    # Fail early with a readable message instead of the AttributeError text
    # that ``None.encode()`` would otherwise produce.
    if not secret_key:
        raise ValueError("Failed to decrypt answer file: no secret key provided")

    try:
        with open("data/answer.enc", "rb") as enc_file:
            encrypted_data = enc_file.read()

        cipher = Fernet(secret_key.encode())
        decrypted_data = cipher.decrypt(encrypted_data).decode("utf-8")
        raw_data = json.loads(decrypted_data)

        # qrel_dict: dataset -> query_id -> {corpus_id: score}
        qrel_dict = defaultdict(lambda: defaultdict(dict))
        for dataset, records in raw_data.items():
            for item in records:
                qid, cid, score = item["query_id"], item["corpus_id"], item["score"]
                qrel_dict[dataset][qid][cid] = score
        return qrel_dict
    except Exception as e:
        # Callers only handle ValueError; keep the original error as the cause.
        raise ValueError(f"Failed to decrypt answer file: {e}") from e
def recall_at_k(corpus_top_100_list, relevant_ids, k=1):
    """Return 1 if any of the first *k* ranked IDs is in *relevant_ids*, else 0.

    Args:
        corpus_top_100_list: Ranked list of candidate IDs, best first.
        relevant_ids: Container of IDs judged relevant for the query.
        k: Number of leading candidates to inspect.

    Returns:
        ``1`` on a hit within the top *k*, otherwise ``0``.
    """
    for candidate in corpus_top_100_list[:k]:
        if candidate in relevant_ids:
            return 1
    return 0
def ndcg_at_k(corpus_top_100_list, rel_dict, k):
    """Compute NDCG@k of a ranked ID list against graded relevance labels.

    Uses linear gain (the relevance score itself) and a ``1 / log2(rank + 1)``
    discount. Only documents that actually appear in the submitted ranking
    contribute to DCG; relevant documents that were not retrieved count
    solely towards the ideal DCG. Appending unretrieved relevant documents
    to the ranking would reward short or empty submissions.

    Args:
        corpus_top_100_list: Ranked candidate IDs, best first. Repeated IDs
            are scored once, at their first position.
        rel_dict: Mapping ``corpus_id -> relevance score`` for the query.
            Scores are assumed to be non-negative.
        k: Rank cutoff. ``None`` means no cutoff.

    Returns:
        NDCG@k as a float in ``[0, 1]``. Returns ``0.0`` when the query has
        no positive relevance label or when ``k`` is not positive.
    """
    if k is not None and k <= 0:
        return 0.0

    # De-duplicate while preserving order so a repeated ID cannot earn
    # credit more than once.
    ranked = list(dict.fromkeys(corpus_top_100_list))[:k]
    gains = np.array([rel_dict.get(item, 0) for item in ranked], dtype=float)

    # Ideal ordering: every judged document sorted by descending relevance.
    ideal = np.sort(np.array(list(rel_dict.values()), dtype=float))[::-1][:k]

    depth = max(len(gains), len(ideal))
    discounts = 1.0 / np.log2(np.arange(2, depth + 2))

    idcg = float(np.dot(ideal, discounts[:len(ideal)]))
    if idcg <= 0:
        # No relevant document exists for this query; define NDCG as 0.
        return 0.0

    dcg = float(np.dot(gains, discounts[:len(gains)]))
    return dcg / idcg
def _parse_ranked_ids(raw):
    """Normalise a ranking given as a comma-separated string or a list into a list of stripped ID strings."""
    parts = raw.split(",") if isinstance(raw, str) else [str(x) for x in raw]
    return [part.strip() for part in parts if part.strip()]


def _mean_percent(values):
    """Return the mean of *values* as a percentage rounded to 2 places, or 0.0 when empty."""
    if not values:
        # np.mean([]) is NaN (with a RuntimeWarning), which also serialises
        # to invalid JSON; report a plain zero instead.
        return 0.0
    return round(np.mean(values) * 100, 2)


def evaluate(pred_data, qrel_dict):
    """Score submitted rankings against the relevance judgments.

    Args:
        pred_data: Mapping ``dataset -> list of query records``. Each record
            has a ``"query_id"`` and a ``"corpus_top_100_list"`` holding the
            ranked corpus IDs, either as a comma-separated string or as a
            list.
        qrel_dict: Mapping ``dataset -> query_id -> {corpus_id: score}``.

    Returns:
        Mapping ``dataset -> {"Recall@1", "NDCG@10", "NDCG@100"}`` with each
        metric averaged over the dataset's submitted queries and expressed
        as a percentage rounded to two decimals. Datasets absent from
        *qrel_dict* are skipped; a dataset with no query records scores 0.0.

    Raises:
        KeyError: If a query record lacks ``"query_id"`` or
            ``"corpus_top_100_list"``.
    """
    # NOTE(review): submitted IDs are compared to qrel IDs as-is (strings
    # after parsing); this assumes the qrels store IDs as strings — confirm.
    results = {}
    for dataset, queries in pred_data.items():
        if dataset not in qrel_dict:
            continue

        dataset_qrels = qrel_dict[dataset]
        recall_1, ndcg_10, ndcg_100 = [], [], []
        for item in queries:
            qid = item["query_id"]
            corpus_top_100_list = _parse_ranked_ids(item["corpus_top_100_list"])

            # Queries without judgments score against an empty label set.
            rel_dict = dataset_qrels.get(qid, {})
            relevant_ids = {cid for cid, score in rel_dict.items() if score > 0}

            recall_1.append(recall_at_k(corpus_top_100_list, relevant_ids, 1))
            ndcg_10.append(ndcg_at_k(corpus_top_100_list, rel_dict, 10))
            ndcg_100.append(ndcg_at_k(corpus_top_100_list, rel_dict, 100))

        results[dataset] = {
            "Recall@1": _mean_percent(recall_1),
            "NDCG@10": _mean_percent(ndcg_10),
            "NDCG@100": _mean_percent(ndcg_100),
        }
    return results
# ==== Gradio Wrapper ====
def process_json(file):
    """Evaluate an uploaded prediction file and return the result as text.

    Args:
        file: Path to the uploaded JSON file, or a file-like upload object
            exposing the path as ``.name``.

    Returns:
        A JSON-formatted string of per-dataset metrics on success, otherwise
        a human-readable error message. Errors are returned rather than
        raised so they can be shown in the output textbox.
    """
    try:
        # Some Gradio versions hand over a temp-file object instead of a path.
        path = getattr(file, "name", file)
        # The context manager closes the upload even when parsing fails.
        with open(path, encoding="utf-8") as fh:
            pred_data = json.load(fh)
    except Exception as e:
        return f"Invalid JSON format: {str(e)}"

    try:
        secret_key = os.getenv("SECRET_KEY")
        qrel_dict = load_and_decrypt_qrel(secret_key)
    except Exception as e:
        return str(e)

    try:
        metrics = evaluate(pred_data, qrel_dict)
        return json.dumps(metrics, indent=2)
    except Exception as e:
        return f"Error during evaluation: {str(e)}"
# ==== Launch Gradio App ====
# def main_gradio():
# example_json = '''{
# "Google_WIT": [
# {"query_id": "1", "corpus_top_100_list": "5, 2, 8, ..."},
# {"query_id": "2", "corpus_top_100_list": "90, 13, 3, ..."}
# ],
# "MSCOCO": [
# {"query_id": "3", "corpus_top_100_list": "122, 35, 22, ..."},
# {"query_id": "2", "corpus_top_100_list": "90, 19, 3, ..."}
# ],
# "OVEN": [
# {"query_id": "3", "corpus_top_100_list": "11, 15, 22, ..."}
# ],
# "VisualNews": [
# {"query_id": "3", "corpus_top_100_list": "101, 35, 22, ..."}
# ]
# }'''
# gr.Interface(
# fn=process_json,
# inputs=gr.File(label="Upload Retrieval Result (JSON)"),
# outputs=gr.Textbox(label="Results"),
# title="Automated Evaluation of MixBench",
# description="Upload a prediction JSON to evaluate Recall@1, NDCG@10, and NDCG@100 against encrypted qrels.\n\nExample input:\n" + example_json
# ).launch(share=True)
def main_gradio():
    """Build the MixBench evaluation interface and launch it.

    The interface takes one uploaded file, passes it to ``process_json`` and
    shows the returned text in a textbox. An example of the expected input is
    embedded in the description. The app is launched with ``share=True``.
    """
    example_json = '''{
"Google_WIT": [
{"query_id": "1", "corpus_top_100_list": "5, 2, 8, ..."},
{"query_id": "2", "corpus_top_100_list": "90, 13, 3, ..."}
],
"MSCOCO": [
{"query_id": "3", "corpus_top_100_list": "122, 35, 22, ..."},
{"query_id": "2", "corpus_top_100_list": "90, 19, 3, ..."}
],
"OVEN": [
{"query_id": "3", "corpus_top_100_list": "11, 15, 22, ..."}
],
"VisualNews": [
{"query_id": "3", "corpus_top_100_list": "101, 35, 22, ..."}
]
}'''

    # Wrap the example in <pre>; presumably the description is rendered as
    # HTML, where <pre> preserves the example's line breaks.
    formatted_example = "<pre>" + example_json + "</pre>"
    description = (
        "Upload a prediction JSON to evaluate Recall@1, NDCG@10, and NDCG@100 against encrypted qrels.<br><br>"
        "Example input format:<br>"
        + formatted_example
    )

    demo = gr.Interface(
        fn=process_json,
        inputs=gr.File(label="Upload Retrieval Result (JSON)"),
        outputs=gr.Textbox(label="Results"),
        title="Automated Evaluation of MixBench",
        description=description,
    )
    demo.launch(share=True)
# Launch the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main_gradio()