foxxy-hm commited on
Commit
95ee61c
·
1 Parent(s): 381a28c

Update src/models/predict_model.py

Browse files
Files changed (1) hide show
  1. src/models/predict_model.py +10 -9
src/models/predict_model.py CHANGED
@@ -5,15 +5,16 @@ from src.models.bm25_utils import BM25Gensim
5
  from src.models.qa_model import *
6
  from tqdm.auto import tqdm
7
  tqdm.pandas()
 
8
 
9
- df_wiki_windows = pd.read_csv("/home/user/app/src/data/processed/wikipedia_20220620_cleaned_v2.csv")
10
- df_wiki = pd.read_csv("/home/user/app/src/data/wikipedia_20220620_short.csv")
11
  df_wiki.title = df_wiki.title.apply(str)
12
 
13
- entity_dict = json.load(open("/home/user/app/src/data/processed/entities.json"))
14
  new_dict = dict()
15
  for key, val in entity_dict.items():
16
- val = val.replace("wiki/", "").replace("_", " ")
17
  entity_dict[key] = val
18
  key = preprocess(key)
19
  new_dict[key.lower()] = val
@@ -22,15 +23,15 @@ title2idx = dict([(x.strip(), y) for x, y in zip(df_wiki.title, df_wiki.index.va
22
 
23
  qa_model = QAEnsembleModel("nguyenvulebinh/vi-mrc-large", ["src/models/qa_model_robust.bin"], entity_dict)
24
  pairwise_model_stage1 = PairwiseModel("nguyenvulebinh/vi-mrc-base").half()
25
- pairwise_model_stage1.load_state_dict(torch.load("/home/user/app/src/models/pairwise_v2.bin"))
26
  pairwise_model_stage1.eval()
27
 
28
  pairwise_model_stage2 = PairwiseModel("nguyenvulebinh/vi-mrc-base").half()
29
- pairwise_model_stage2.load_state_dict(torch.load("/home/user/app/src/models/pairwise_stage2_seed0.bin"))
30
 
31
- bm25_model_stage1 = BM25Gensim("/home/user/app/src/models/bm25_stage1/", entity_dict, title2idx)
32
- bm25_model_stage2_full = BM25Gensim("/home/user/app/src/models/bm25_stage2/full_text/", entity_dict, title2idx)
33
- bm25_model_stage2_title = BM25Gensim("/home/user/app/src/models/bm25_stage2/title/", entity_dict, title2idx)
34
 
35
  def get_answer_e2e(question):
36
  #Bm25 retrieval for top200 candidates
 
5
  from src.models.qa_model import *
6
  from tqdm.auto import tqdm
7
  tqdm.pandas()
8
+ from datasets import load_dataset
9
 
10
+ df_wiki_windows = load_dataset("foxxy-hm/e2eqa-wiki", data_files="processed/wikipedia_20220620_cleaned_v2.csv")["train"].to_pandas()
11
+ df_wiki = load_dataset("foxxy-hm/e2eqa-wiki", data_files="processed/wikipedia_20220620_short.csv")["train"].to_pandas()
12
  df_wiki.title = df_wiki.title.apply(str)
13
 
14
+ entity_dict = load_dataset("foxxy-hm/e2eqa-wiki", data_files="processed/entities.json")["train"].to_dict()
15
  new_dict = dict()
16
  for key, val in entity_dict.items():
17
+ val = val[0].replace("wiki/", "").replace("_", " ")
18
  entity_dict[key] = val
19
  key = preprocess(key)
20
  new_dict[key.lower()] = val
 
23
 
24
  qa_model = QAEnsembleModel("nguyenvulebinh/vi-mrc-large", ["src/models/qa_model_robust.bin"], entity_dict)
25
  pairwise_model_stage1 = PairwiseModel("nguyenvulebinh/vi-mrc-base").half()
26
+ pairwise_model_stage1.load_state_dict(torch.load("/src/models/pairwise_v2.bin"))
27
  pairwise_model_stage1.eval()
28
 
29
  pairwise_model_stage2 = PairwiseModel("nguyenvulebinh/vi-mrc-base").half()
30
+ pairwise_model_stage2.load_state_dict(torch.load("/src/models/pairwise_stage2_seed0.bin"))
31
 
32
+ bm25_model_stage1 = BM25Gensim("/src/models/bm25_stage1/", entity_dict, title2idx)
33
+ bm25_model_stage2_full = BM25Gensim("/src/models/bm25_stage2/full_text/", entity_dict, title2idx)
34
+ bm25_model_stage2_title = BM25Gensim("/src/models/bm25_stage2/title/", entity_dict, title2idx)
35
 
36
  def get_answer_e2e(question):
37
  #Bm25 retrieval for top200 candidates