foxxy-hm commited on
Commit
242ba47
·
1 Parent(s): 38f77f3

Update src/models/predict_model.py

Browse files
Files changed (1) hide show
  1. src/models/predict_model.py +6 -6
src/models/predict_model.py CHANGED
@@ -21,17 +21,17 @@ for key, val in entity_dict.items():
21
  entity_dict.update(new_dict)
22
  title2idx = dict([(x.strip(), y) for x, y in zip(df_wiki.title, df_wiki.index.values)])
23
 
24
- qa_model = QAEnsembleModel("nguyenvulebinh/vi-mrc-large", ["../../models/qa_model_robust.bin"], entity_dict)
25
  pairwise_model_stage1 = PairwiseModel("nguyenvulebinh/vi-mrc-base").half()
26
- pairwise_model_stage1.load_state_dict(torch.load("../../models/pairwise_v2.bin"))
27
  pairwise_model_stage1.eval()
28
 
29
  pairwise_model_stage2 = PairwiseModel("nguyenvulebinh/vi-mrc-base").half()
30
- pairwise_model_stage2.load_state_dict(torch.load("../../models/pairwise_stage2_seed0.bin"))
31
 
32
- bm25_model_stage1 = BM25Gensim("../../models/bm25_stage1/", entity_dict, title2idx)
33
- bm25_model_stage2_full = BM25Gensim("../../models/bm25_stage2/full_text/", entity_dict, title2idx)
34
- bm25_model_stage2_title = BM25Gensim("../../models/bm25_stage2/title/", entity_dict, title2idx)
35
 
36
  def get_answer_e2e(question):
37
  #Bm25 retrieval for top200 candidates
 
21
  entity_dict.update(new_dict)
22
  title2idx = dict([(x.strip(), y) for x, y in zip(df_wiki.title, df_wiki.index.values)])
23
 
24
+ qa_model = QAEnsembleModel("nguyenvulebinh/vi-mrc-large", ["models/qa_model_robust.bin"], entity_dict)
25
  pairwise_model_stage1 = PairwiseModel("nguyenvulebinh/vi-mrc-base").half()
26
+ pairwise_model_stage1.load_state_dict(torch.load("models/pairwise_v2.bin"))
27
  pairwise_model_stage1.eval()
28
 
29
  pairwise_model_stage2 = PairwiseModel("nguyenvulebinh/vi-mrc-base").half()
30
+ pairwise_model_stage2.load_state_dict(torch.load("models/pairwise_stage2_seed0.bin"))
31
 
32
+ bm25_model_stage1 = BM25Gensim("models/bm25_stage1/", entity_dict, title2idx)
33
+ bm25_model_stage2_full = BM25Gensim("models/bm25_stage2/full_text/", entity_dict, title2idx)
34
+ bm25_model_stage2_title = BM25Gensim("models/bm25_stage2/title/", entity_dict, title2idx)
35
 
36
  def get_answer_e2e(question):
37
  #Bm25 retrieval for top200 candidates