Spaces:
Build error
Build error
Update src/models/predict_model.py
Browse files- src/models/predict_model.py +10 -9
src/models/predict_model.py
CHANGED
@@ -5,15 +5,16 @@ from src.models.bm25_utils import BM25Gensim
|
|
5 |
from src.models.qa_model import *
|
6 |
from tqdm.auto import tqdm
|
7 |
tqdm.pandas()
|
|
|
8 |
|
9 |
-
df_wiki_windows =
|
10 |
-
df_wiki =
|
11 |
df_wiki.title = df_wiki.title.apply(str)
|
12 |
|
13 |
-
entity_dict =
|
14 |
new_dict = dict()
|
15 |
for key, val in entity_dict.items():
|
16 |
-
val = val.replace("wiki/", "").replace("_", " ")
|
17 |
entity_dict[key] = val
|
18 |
key = preprocess(key)
|
19 |
new_dict[key.lower()] = val
|
@@ -22,15 +23,15 @@ title2idx = dict([(x.strip(), y) for x, y in zip(df_wiki.title, df_wiki.index.va
|
|
22 |
|
23 |
qa_model = QAEnsembleModel("nguyenvulebinh/vi-mrc-large", ["src/models/qa_model_robust.bin"], entity_dict)
|
24 |
pairwise_model_stage1 = PairwiseModel("nguyenvulebinh/vi-mrc-base").half()
|
25 |
-
pairwise_model_stage1.load_state_dict(torch.load("/
|
26 |
pairwise_model_stage1.eval()
|
27 |
|
28 |
pairwise_model_stage2 = PairwiseModel("nguyenvulebinh/vi-mrc-base").half()
|
29 |
-
pairwise_model_stage2.load_state_dict(torch.load("/
|
30 |
|
31 |
-
bm25_model_stage1 = BM25Gensim("/
|
32 |
-
bm25_model_stage2_full = BM25Gensim("/
|
33 |
-
bm25_model_stage2_title = BM25Gensim("/
|
34 |
|
35 |
def get_answer_e2e(question):
|
36 |
#Bm25 retrieval for top200 candidates
|
|
|
5 |
from src.models.qa_model import *
|
6 |
from tqdm.auto import tqdm
|
7 |
tqdm.pandas()
|
8 |
+
from datasets import load_dataset
|
9 |
|
10 |
+
df_wiki_windows = load_dataset("foxxy-hm/e2eqa-wiki", data_files="processed/wikipedia_20220620_cleaned_v2.csv")["train"].to_pandas()
|
11 |
+
df_wiki = load_dataset("foxxy-hm/e2eqa-wiki", data_files="processed/wikipedia_20220620_short.csv")["train"].to_pandas()
|
12 |
df_wiki.title = df_wiki.title.apply(str)
|
13 |
|
14 |
+
entity_dict = load_dataset("foxxy-hm/e2eqa-wiki", data_files="processed/entities.json")["train"].to_dict()
|
15 |
new_dict = dict()
|
16 |
for key, val in entity_dict.items():
|
17 |
+
val = val[0].replace("wiki/", "").replace("_", " ")
|
18 |
entity_dict[key] = val
|
19 |
key = preprocess(key)
|
20 |
new_dict[key.lower()] = val
|
|
|
23 |
|
24 |
qa_model = QAEnsembleModel("nguyenvulebinh/vi-mrc-large", ["src/models/qa_model_robust.bin"], entity_dict)
|
25 |
pairwise_model_stage1 = PairwiseModel("nguyenvulebinh/vi-mrc-base").half()
|
26 |
+
pairwise_model_stage1.load_state_dict(torch.load("/src/models/pairwise_v2.bin"))
|
27 |
pairwise_model_stage1.eval()
|
28 |
|
29 |
pairwise_model_stage2 = PairwiseModel("nguyenvulebinh/vi-mrc-base").half()
|
30 |
+
pairwise_model_stage2.load_state_dict(torch.load("/src/models/pairwise_stage2_seed0.bin"))
|
31 |
|
32 |
+
bm25_model_stage1 = BM25Gensim("/src/models/bm25_stage1/", entity_dict, title2idx)
|
33 |
+
bm25_model_stage2_full = BM25Gensim("/src/models/bm25_stage2/full_text/", entity_dict, title2idx)
|
34 |
+
bm25_model_stage2_title = BM25Gensim("/src/models/bm25_stage2/title/", entity_dict, title2idx)
|
35 |
|
36 |
def get_answer_e2e(question):
|
37 |
#Bm25 retrieval for top200 candidates
|