foxxy-hm commited on
Commit
f62b65f
·
1 Parent(s): 95ee61c

Update src/models/predict_model.py

Browse files
Files changed (1) hide show
  1. src/models/predict_model.py +2 -2
src/models/predict_model.py CHANGED
@@ -8,7 +8,7 @@ tqdm.pandas()
8
  from datasets import load_dataset
9
 
10
  df_wiki_windows = load_dataset("foxxy-hm/e2eqa-wiki", data_files="processed/wikipedia_20220620_cleaned_v2.csv")["train"].to_pandas()
11
- df_wiki = load_dataset("foxxy-hm/e2eqa-wiki", data_files="processed/wikipedia_20220620_short.csv")["train"].to_pandas()
12
  df_wiki.title = df_wiki.title.apply(str)
13
 
14
  entity_dict = load_dataset("foxxy-hm/e2eqa-wiki", data_files="processed/entities.json")["train"].to_dict()
@@ -21,7 +21,7 @@ for key, val in entity_dict.items():
21
  entity_dict.update(new_dict)
22
  title2idx = dict([(x.strip(), y) for x, y in zip(df_wiki.title, df_wiki.index.values)])
23
 
24
- qa_model = QAEnsembleModel("nguyenvulebinh/vi-mrc-large", ["src/models/qa_model_robust.bin"], entity_dict)
25
  pairwise_model_stage1 = PairwiseModel("nguyenvulebinh/vi-mrc-base").half()
26
  pairwise_model_stage1.load_state_dict(torch.load("/src/models/pairwise_v2.bin"))
27
  pairwise_model_stage1.eval()
 
8
  from datasets import load_dataset
9
 
10
  df_wiki_windows = load_dataset("foxxy-hm/e2eqa-wiki", data_files="processed/wikipedia_20220620_cleaned_v2.csv")["train"].to_pandas()
11
+ df_wiki = load_dataset("foxxy-hm/e2eqa-wiki", data_files="wikipedia_20220620_short.csv")["train"].to_pandas()
12
  df_wiki.title = df_wiki.title.apply(str)
13
 
14
  entity_dict = load_dataset("foxxy-hm/e2eqa-wiki", data_files="processed/entities.json")["train"].to_dict()
 
21
  entity_dict.update(new_dict)
22
  title2idx = dict([(x.strip(), y) for x, y in zip(df_wiki.title, df_wiki.index.values)])
23
 
24
+ qa_model = QAEnsembleModel("nguyenvulebinh/vi-mrc-large", ["/src/models/qa_model_robust.bin"], entity_dict)
25
  pairwise_model_stage1 = PairwiseModel("nguyenvulebinh/vi-mrc-base").half()
26
  pairwise_model_stage1.load_state_dict(torch.load("/src/models/pairwise_v2.bin"))
27
  pairwise_model_stage1.eval()