foxxy-hm commited on
Commit
50d2c82
·
1 Parent(s): 040aace

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -1
app.py CHANGED
@@ -1,5 +1,83 @@
1
  import streamlit as st
2
- from src.models.predict_model import *
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  with st.sidebar:
5
  st.write("# 🤖 Language Models")
 
1
  import streamlit as st
2
+ # from src.models.predict_model import *
3
+
4
+ from src.models.pairwise_model import *
5
+ from src.features.text_utils import *
6
+ import regex as re
7
+ from src.models.bm25_utils import BM25Gensim
8
+ from src.models.qa_model import *
9
+ from tqdm.auto import tqdm
10
+ tqdm.pandas()
11
+ from datasets import load_dataset
12
+
13
+ df_wiki_windows = load_dataset("foxxy-hm/e2eqa-wiki", data_files="processed/wikipedia_20220620_cleaned_v2.csv")["train"].to_pandas()
14
+ df_wiki = load_dataset("foxxy-hm/e2eqa-wiki", data_files="wikipedia_20220620_short.csv")["train"].to_pandas()
15
+ df_wiki.title = df_wiki.title.apply(str)
16
+
17
+ entity_dict = load_dataset("foxxy-hm/e2eqa-wiki", data_files="processed/entities.json")["train"].to_dict()
18
+ new_dict = dict()
19
+ for key, val in entity_dict.items():
20
+ val = val[0].replace("wiki/", "").replace("_", " ")
21
+ entity_dict[key] = val
22
+ key = preprocess(key)
23
+ new_dict[key.lower()] = val
24
+ entity_dict.update(new_dict)
25
+ title2idx = dict([(x.strip(), y) for x, y in zip(df_wiki.title, df_wiki.index.values)])
26
+
27
+ qa_model = QAEnsembleModel("nguyenvulebinh/vi-mrc-large", ["models/qa_model_robust.bin"], entity_dict)
28
+ pairwise_model_stage1 = PairwiseModel("nguyenvulebinh/vi-mrc-base")#.half()
29
+ pairwise_model_stage1.load_state_dict(torch.load("models/pairwise_v2.bin", map_location=torch.device('cpu')))
30
+ pairwise_model_stage1.eval()
31
+
32
+ pairwise_model_stage2 = PairwiseModel("nguyenvulebinh/vi-mrc-base")#.half()
33
+ pairwise_model_stage2.load_state_dict(torch.load("models/pairwise_stage2_seed0.bin", map_location=torch.device('cpu')))
34
+
35
+ bm25_model_stage1 = BM25Gensim("models/bm25_stage1/", entity_dict, title2idx)
36
+ bm25_model_stage2_full = BM25Gensim("models/bm25_stage2/full_text/", entity_dict, title2idx)
37
+ bm25_model_stage2_title = BM25Gensim("models/bm25_stage2/title/", entity_dict, title2idx)
38
+
39
+ def get_answer_e2e(question):
40
+ #Bm25 retrieval for top200 candidates
41
+ query = preprocess(question).lower()
42
+ top_n, bm25_scores = bm25_model_stage1.get_topk_stage1(query, topk=200)
43
+ titles = [preprocess(df_wiki_windows.title.values[i]) for i in top_n]
44
+ texts = [preprocess(df_wiki_windows.text.values[i]) for i in top_n]
45
+
46
+ #Reranking with pairwise model for top10
47
+ question = preprocess(question)
48
+ ranking_preds = pairwise_model_stage1.stage1_ranking(question, texts)
49
+ ranking_scores = ranking_preds * bm25_scores
50
+
51
+ #Question answering
52
+ best_idxs = np.argsort(ranking_scores)[-10:]
53
+ ranking_scores = np.array(ranking_scores)[best_idxs]
54
+ texts = np.array(texts)[best_idxs]
55
+ best_answer = qa_model(question, texts, ranking_scores)
56
+ if best_answer is None:
57
+ return "Chịu"
58
+ bm25_answer = preprocess(str(best_answer).lower(), max_length=128, remove_puncts=True)
59
+
60
+ #Entity mapping
61
+ if not check_number(bm25_answer):
62
+ bm25_question = preprocess(str(question).lower(), max_length=128, remove_puncts=True)
63
+ bm25_question_answer = bm25_question + " " + bm25_answer
64
+ candidates, scores = bm25_model_stage2_title.get_topk_stage2(bm25_answer, raw_answer=best_answer)
65
+ titles = [df_wiki.title.values[i] for i in candidates]
66
+ texts = [df_wiki.text.values[i] for i in candidates]
67
+ ranking_preds = pairwise_model_stage2.stage2_ranking(question, best_answer, titles, texts)
68
+ if ranking_preds.max() >= 0.1:
69
+ final_answer = titles[ranking_preds.argmax()]
70
+ else:
71
+ candidates, scores = bm25_model_stage2_full.get_topk_stage2(bm25_question_answer)
72
+ titles = [df_wiki.title.values[i] for i in candidates] + titles
73
+ texts = [df_wiki.text.values[i] for i in candidates] + texts
74
+ ranking_preds = np.concatenate(
75
+ [pairwise_model_stage2.stage2_ranking(question, best_answer, titles, texts), ranking_preds])
76
+ final_answer = "wiki/"+titles[ranking_preds.argmax()].replace(" ","_")
77
+ else:
78
+ final_answer = bm25_answer.lower()
79
+ return final_answer
80
+
81
 
82
  with st.sidebar:
83
  st.write("# 🤖 Language Models")