Shiyu Zhao
committed on
Commit
·
2c4b436
1
Parent(s):
1250c3d
Update space
Browse files- app.py +92 -2
- requirements.txt +2 -1
app.py
CHANGED
@@ -4,6 +4,98 @@ import os
|
|
4 |
import re
|
5 |
from datetime import datetime
|
6 |
import json
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
|
8 |
|
9 |
# Sample data based on your table (you'll need to update this with the full dataset)
|
@@ -58,7 +150,6 @@ data_human_generated = {
|
|
58 |
df_synthesized_full = pd.DataFrame(data_synthesized_full)
|
59 |
df_synthesized_10 = pd.DataFrame(data_synthesized_10)
|
60 |
df_human_generated = pd.DataFrame(data_human_generated)
|
61 |
-
|
62 |
def validate_email(email_str):
|
63 |
"""Validate email format(s)"""
|
64 |
emails = [e.strip() for e in email_str.split(';')]
|
@@ -247,7 +338,6 @@ def add_submission_form(demo):
|
|
247 |
],
|
248 |
outputs=result
|
249 |
)
|
250 |
-
|
251 |
def format_dataframe(df, dataset):
|
252 |
# Filter the dataframe for the selected dataset
|
253 |
columns = ['Method'] + [col for col in df.columns if dataset in col]
|
|
|
4 |
import re
|
5 |
from datetime import datetime
|
6 |
import json
|
7 |
+
import pandas as pd
|
8 |
+
import torch
|
9 |
+
import numpy as np
|
10 |
+
from tqdm import tqdm
|
11 |
+
from concurrent.futures import ProcessPoolExecutor, as_completed
|
12 |
+
|
13 |
+
from stark_qa import load_qa
|
14 |
+
from stark_qa.evaluator import Evaluator
|
15 |
+
|
16 |
+
|
17 |
+
def process_single_instance(args):
    """Evaluate the predicted ranking for a single query.

    Args:
        args: Tuple of (idx, eval_csv, qa_dataset, evaluator, eval_metrics),
            where `idx` indexes `qa_dataset`, `eval_csv` is a DataFrame with
            'query_id' and 'pred_rank' columns, and `eval_metrics` is the list
            of metric names forwarded to `evaluator.evaluate`.

    Returns:
        dict: Metric results from `evaluator.evaluate`, augmented with the
        "idx" and "query_id" keys of the processed query.

    Raises:
        IndexError: If no prediction row exists for this query.
        RuntimeError: If fetching the prediction rank fails unexpectedly.
        ValueError: If a string-valued pred_rank cannot be parsed as a list.
        TypeError: If pred_rank is not a list after parsing.
    """
    import ast  # local import: only needed to parse pred_rank strings safely

    idx, eval_csv, qa_dataset, evaluator, eval_metrics = args
    query, query_id, answer_ids, meta_info = qa_dataset[idx]

    # A missing query would make Series.item() raise ValueError, which the
    # previous broad handler misreported as RuntimeError; test emptiness
    # explicitly so the intended IndexError (with its guidance) is raised.
    matches = eval_csv.loc[eval_csv['query_id'] == query_id, 'pred_rank']
    if matches.empty:
        raise IndexError(f'Error when processing query_id={query_id}, please make sure the predicted results exist for this query.')
    try:
        pred_rank = matches.item()
    except Exception as e:
        raise RuntimeError(f'Unexpected error occurred while fetching prediction rank for query_id={query_id}: {e}') from e

    if isinstance(pred_rank, str):
        try:
            # SECURITY: the CSV is untrusted user input — use ast.literal_eval
            # rather than eval(); it only accepts Python literals and cannot
            # execute arbitrary code.
            pred_rank = ast.literal_eval(pred_rank)
        except (SyntaxError, ValueError) as e:
            raise ValueError(f'Failed to parse pred_rank as a list for query_id={query_id}: {e}') from e

    if not isinstance(pred_rank, list):
        raise TypeError(f'Error when processing query_id={query_id}, expected pred_rank to be a list but got {type(pred_rank)}.')

    # Only the top-100 predictions are scored; earlier rank -> larger value.
    pred_dict = {pred_rank[i]: -i for i in range(min(100, len(pred_rank)))}
    answer_ids = torch.LongTensor(answer_ids)
    result = evaluator.evaluate(pred_dict, answer_ids, metrics=eval_metrics)

    result["idx"], result["query_id"] = idx, query_id
    return result
|
43 |
+
|
44 |
+
|
45 |
+
def compute_metrics(csv_path: str, dataset: str, split: str, num_workers: int = 4):
    """Compute ranking metrics for a submitted prediction CSV.

    Args:
        csv_path: Path to a CSV with 'query_id' and 'pred_rank' columns.
        dataset: One of 'amazon', 'mag', 'prime'.
        split: One of 'test', 'test-0.1', 'human_generated_eval'.
        num_workers: Number of parallel evaluation processes.

    Returns:
        A dict mapping metric name -> mean value over the split on success,
        or a human-readable error string on failure (the caller displays
        either result directly).
    """
    # Candidate-node id ranges for each supported dataset.
    candidate_ids_dict = {
        'amazon': list(range(957192)),
        'mag': list(range(1172724, 1872968)),
        'prime': list(range(129375)),
    }
    try:
        eval_csv = pd.read_csv(csv_path)
        if 'query_id' not in eval_csv.columns:
            raise ValueError('No `query_id` column found in the submitted csv.')
        if 'pred_rank' not in eval_csv.columns:
            raise ValueError('No `pred_rank` column found in the submitted csv.')

        # Keep only the columns the workers need (smaller pickling payload).
        eval_csv = eval_csv[['query_id', 'pred_rank']]

        if dataset not in candidate_ids_dict:
            raise ValueError(f"Invalid dataset '{dataset}', expected one of {list(candidate_ids_dict.keys())}.")
        if split not in ['test', 'test-0.1', 'human_generated_eval']:
            raise ValueError(f"Invalid split '{split}', expected one of ['test', 'test-0.1', 'human_generated_eval'].")

        evaluator = Evaluator(candidate_ids_dict[dataset])
        eval_metrics = ['hit@1', 'hit@5', 'recall@20', 'mrr']
        qa_dataset = load_qa(dataset, human_generated_eval=split == 'human_generated_eval')
        split_idx = qa_dataset.get_idx_split()
        all_indices = split_idx[split].tolist()

        # Fan out one evaluation task per query; results arrive out of order,
        # which is fine since we only aggregate per-metric means.
        args = [(idx, eval_csv, qa_dataset, evaluator, eval_metrics) for idx in all_indices]
        results_list = []
        with ProcessPoolExecutor(max_workers=num_workers) as executor:
            futures = [executor.submit(process_single_instance, arg) for arg in args]
            for future in tqdm(as_completed(futures), total=len(futures)):
                # .result() re-raises any worker-side exception here.
                results_list.append(future.result())

        # Average each metric over the per-query results. (The previous
        # version concatenated results onto eval_csv and depended on pandas
        # NaN-skipping during the mean; aggregating the results frame
        # directly yields the same numbers without that fragility.)
        results_df = pd.DataFrame(results_list)
        return {metric: float(results_df[metric].mean()) for metric in eval_metrics}

    except pd.errors.EmptyDataError:
        return "Error: The CSV file is empty or could not be read. Please check the file and try again."
    except FileNotFoundError:
        return f"Error: The file {csv_path} could not be found. Please check the file path and try again."
    except Exception as error:
        return f"{error}"
|
97 |
+
|
98 |
+
|
99 |
|
100 |
|
101 |
# Sample data based on your table (you'll need to update this with the full dataset)
|
|
|
150 |
df_synthesized_full = pd.DataFrame(data_synthesized_full)
|
151 |
df_synthesized_10 = pd.DataFrame(data_synthesized_10)
|
152 |
df_human_generated = pd.DataFrame(data_human_generated)
|
|
|
153 |
def validate_email(email_str):
|
154 |
"""Validate email format(s)"""
|
155 |
emails = [e.strip() for e in email_str.split(';')]
|
|
|
338 |
],
|
339 |
outputs=result
|
340 |
)
|
|
|
341 |
def format_dataframe(df, dataset):
|
342 |
# Filter the dataframe for the selected dataset
|
343 |
columns = ['Method'] + [col for col in df.columns if dataset in col]
|
requirements.txt
CHANGED
@@ -13,4 +13,5 @@ python-dateutil
|
|
13 |
tqdm
|
14 |
transformers
|
15 |
tokenizers>=0.15.0
|
16 |
-
sentencepiece
|
|
|
|
13 |
tqdm
|
14 |
transformers
|
15 |
tokenizers>=0.15.0
|
16 |
+
sentencepiece
|
17 |
+
stark_qa
|