Shiyu Zhao committed
Commit 2c4b436 · Parent(s): 1250c3d

Update space
Files changed (2):
  1. app.py +92 -2
  2. requirements.txt +2 -1
app.py CHANGED
@@ -4,6 +4,98 @@ import os
 import re
 from datetime import datetime
 import json
+import pandas as pd
+import torch
+import numpy as np
+from tqdm import tqdm
+from concurrent.futures import ProcessPoolExecutor, as_completed
+
+from stark_qa import load_qa
+from stark_qa.evaluator import Evaluator
+
+
+def process_single_instance(args):
+    idx, eval_csv, qa_dataset, evaluator, eval_metrics = args
+    query, query_id, answer_ids, meta_info = qa_dataset[idx]
+
+    try:
+        pred_rank = eval_csv[eval_csv['query_id'] == query_id]['pred_rank'].item()
+    except IndexError:
+        raise IndexError(f'Error when processing query_id={query_id}, please make sure the predicted results exist for this query.')
+    except Exception as e:
+        raise RuntimeError(f'Unexpected error occurred while fetching prediction rank for query_id={query_id}: {e}')
+
+    if isinstance(pred_rank, str):
+        try:
+            pred_rank = eval(pred_rank)
+        except SyntaxError as e:
+            raise ValueError(f'Failed to parse pred_rank as a list for query_id={query_id}: {e}')
+
+    if not isinstance(pred_rank, list):
+        raise TypeError(f'Error when processing query_id={query_id}, expected pred_rank to be a list but got {type(pred_rank)}.')
+
+    pred_dict = {pred_rank[i]: -i for i in range(min(100, len(pred_rank)))}
+    answer_ids = torch.LongTensor(answer_ids)
+    result = evaluator.evaluate(pred_dict, answer_ids, metrics=eval_metrics)
+
+    result["idx"], result["query_id"] = idx, query_id
+    return result
+
+
+def compute_metrics(csv_path: str, dataset: str, split: str, num_workers: int = 4):
+    candidate_ids_dict = {
+        'amazon': [i for i in range(957192)],
+        'mag': [i for i in range(1172724, 1872968)],
+        'prime': [i for i in range(129375)]
+    }
+    try:
+        eval_csv = pd.read_csv(csv_path)
+        if 'query_id' not in eval_csv.columns:
+            raise ValueError('No `query_id` column found in the submitted csv.')
+        if 'pred_rank' not in eval_csv.columns:
+            raise ValueError('No `pred_rank` column found in the submitted csv.')
+
+        eval_csv = eval_csv[['query_id', 'pred_rank']]
+
+        if dataset not in candidate_ids_dict:
+            raise ValueError(f"Invalid dataset '{dataset}', expected one of {list(candidate_ids_dict.keys())}.")
+        if split not in ['test', 'test-0.1', 'human_generated_eval']:
+            raise ValueError(f"Invalid split '{split}', expected one of ['test', 'test-0.1', 'human_generated_eval'].")
+
+        evaluator = Evaluator(candidate_ids_dict[dataset])
+        eval_metrics = ['hit@1', 'hit@5', 'recall@20', 'mrr']
+        qa_dataset = load_qa(dataset, human_generated_eval=split == 'human_generated_eval')
+        split_idx = qa_dataset.get_idx_split()
+        all_indices = split_idx[split].tolist()
+
+        results_list = []
+        query_ids = []
+
+        # Prepare args for each worker
+        args = [(idx, eval_csv, qa_dataset, evaluator, eval_metrics) for idx in all_indices]
+
+        with ProcessPoolExecutor(max_workers=num_workers) as executor:
+            futures = [executor.submit(process_single_instance, arg) for arg in args]
+            for future in tqdm(as_completed(futures), total=len(futures)):
+                result = future.result()  # This will raise an error if the worker encountered one
+                results_list.append(result)
+                query_ids.append(result['query_id'])
+
+        # Concatenate results and compute final metrics
+        eval_csv = pd.concat([eval_csv, pd.DataFrame(results_list)], ignore_index=True)
+        final_results = {
+            metric: np.mean(eval_csv[eval_csv['query_id'].isin(query_ids)][metric]) for metric in eval_metrics
+        }
+        return final_results
+
+    except pd.errors.EmptyDataError:
+        return "Error: The CSV file is empty or could not be read. Please check the file and try again."
+    except FileNotFoundError:
+        return f"Error: The file {csv_path} could not be found. Please check the file path and try again."
+    except Exception as error:
+        return f"{error}"
+
+
 
 
 # Sample data based on your table (you'll need to update this with the full dataset)
@@ -58,7 +150,6 @@ data_human_generated = {
 df_synthesized_full = pd.DataFrame(data_synthesized_full)
 df_synthesized_10 = pd.DataFrame(data_synthesized_10)
 df_human_generated = pd.DataFrame(data_human_generated)
-
 def validate_email(email_str):
     """Validate email format(s)"""
     emails = [e.strip() for e in email_str.split(';')]
@@ -247,7 +338,6 @@ def add_submission_form(demo):
         ],
         outputs=result
     )
-
 def format_dataframe(df, dataset):
     # Filter the dataframe for the selected dataset
     columns = ['Method'] + [col for col in df.columns if dataset in col]
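For reference, a minimal smoke test of the new compute_metrics entry point might look like the sketch below. It is not part of the commit: the file name predictions.csv, the toy query ids, and the candidate ranks are all illustrative, and a real submission needs one row per query id in the chosen split. Note that pred_rank round-trips through the CSV as a string; compute_metrics parses it back into a list with eval(), so each cell should hold a Python-style list literal of candidate ids, best match first (only the first 100 entries are scored).

# Hypothetical usage sketch (not part of this commit).
import pandas as pd

# Toy submission. The query ids (0, 1) and the ranked candidate ids are
# placeholders; a real file must cover every query id in the split.
submission = pd.DataFrame({
    'query_id': [0, 1],
    'pred_rank': [str(list(range(100))), str(list(range(100, 200)))],
})
submission.to_csv('predictions.csv', index=False)

results = compute_metrics(
    'predictions.csv',
    dataset='amazon',
    split='human_generated_eval',
    num_workers=2,
)
# compute_metrics returns a dict of metric means on success
# and an error string on failure.
print(results)

Two behaviors of the committed code are worth knowing when testing: in recent pandas versions Series.item() raises ValueError (not IndexError) when a query id is missing or duplicated, so such rows surface through the generic `except Exception` branch as a RuntimeError; and because each task submitted to the ProcessPoolExecutor carries eval_csv, qa_dataset, and the evaluator in its argument tuple, per-task pickling overhead grows with the size of the split.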
requirements.txt CHANGED
@@ -13,4 +13,5 @@ python-dateutil
 tqdm
 transformers
 tokenizers>=0.15.0
-sentencepiece
+sentencepiece
+stark_qa
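The new stark_qa requirement provides the load_qa loader and Evaluator class that the updated app.py imports. A hypothetical post-install sanity check (not part of the commit):

# Confirms the package name added to requirements.txt resolves to the
# stark_qa module that app.py imports.
from stark_qa import load_qa
from stark_qa.evaluator import Evaluator
print(load_qa, Evaluator)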