foxxy-hm committed on
Commit
381a28c
·
1 Parent(s): b2ed4d2

Update src/models/predict_model.py

Browse files
Files changed (1) hide show
  1. src/models/predict_model.py +8 -8
src/models/predict_model.py CHANGED
@@ -6,11 +6,11 @@ from src.models.qa_model import *
6
  from tqdm.auto import tqdm
7
  tqdm.pandas()
8
 
9
- df_wiki_windows = pd.read_csv("src/data/processed/wikipedia_20220620_cleaned_v2.csv")
10
- df_wiki = pd.read_csv("src/data/wikipedia_20220620_short.csv")
11
  df_wiki.title = df_wiki.title.apply(str)
12
 
13
- entity_dict = json.load(open("src/data/processed/entities.json"))
14
  new_dict = dict()
15
  for key, val in entity_dict.items():
16
  val = val.replace("wiki/", "").replace("_", " ")
@@ -22,15 +22,15 @@ title2idx = dict([(x.strip(), y) for x, y in zip(df_wiki.title, df_wiki.index.va
22
 
23
  qa_model = QAEnsembleModel("nguyenvulebinh/vi-mrc-large", ["src/models/qa_model_robust.bin"], entity_dict)
24
  pairwise_model_stage1 = PairwiseModel("nguyenvulebinh/vi-mrc-base").half()
25
- pairwise_model_stage1.load_state_dict(torch.load("src/models/pairwise_v2.bin"))
26
  pairwise_model_stage1.eval()
27
 
28
  pairwise_model_stage2 = PairwiseModel("nguyenvulebinh/vi-mrc-base").half()
29
- pairwise_model_stage2.load_state_dict(torch.load("src/models/pairwise_stage2_seed0.bin"))
30
 
31
- bm25_model_stage1 = BM25Gensim("src/models/bm25_stage1/", entity_dict, title2idx)
32
- bm25_model_stage2_full = BM25Gensim("src/models/bm25_stage2/full_text/", entity_dict, title2idx)
33
- bm25_model_stage2_title = BM25Gensim("src/models/bm25_stage2/title/", entity_dict, title2idx)
34
 
35
  def get_answer_e2e(question):
36
  #Bm25 retrieval for top200 candidates
 
6
  from tqdm.auto import tqdm
7
  tqdm.pandas()
8
 
9
+ df_wiki_windows = pd.read_csv("/home/user/app/src/data/processed/wikipedia_20220620_cleaned_v2.csv")
10
+ df_wiki = pd.read_csv("/home/user/app/src/data/wikipedia_20220620_short.csv")
11
  df_wiki.title = df_wiki.title.apply(str)
12
 
13
+ entity_dict = json.load(open("/home/user/app/src/data/processed/entities.json"))
14
  new_dict = dict()
15
  for key, val in entity_dict.items():
16
  val = val.replace("wiki/", "").replace("_", " ")
 
22
 
23
  qa_model = QAEnsembleModel("nguyenvulebinh/vi-mrc-large", ["src/models/qa_model_robust.bin"], entity_dict)
24
  pairwise_model_stage1 = PairwiseModel("nguyenvulebinh/vi-mrc-base").half()
25
+ pairwise_model_stage1.load_state_dict(torch.load("/home/user/app/src/models/pairwise_v2.bin"))
26
  pairwise_model_stage1.eval()
27
 
28
  pairwise_model_stage2 = PairwiseModel("nguyenvulebinh/vi-mrc-base").half()
29
+ pairwise_model_stage2.load_state_dict(torch.load("/home/user/app/src/models/pairwise_stage2_seed0.bin"))
30
 
31
+ bm25_model_stage1 = BM25Gensim("/home/user/app/src/models/bm25_stage1/", entity_dict, title2idx)
32
+ bm25_model_stage2_full = BM25Gensim("/home/user/app/src/models/bm25_stage2/full_text/", entity_dict, title2idx)
33
+ bm25_model_stage2_title = BM25Gensim("/home/user/app/src/models/bm25_stage2/title/", entity_dict, title2idx)
34
 
35
  def get_answer_e2e(question):
36
  #Bm25 retrieval for top200 candidates