Linhz committed (verified)
Commit: 3789e3a
Parent(s): 8d196c9

Update app.py

Files changed (1):
  1. app.py (+6 -3)
app.py CHANGED
@@ -4,6 +4,7 @@ from sentence_transformers import SentenceTransformer
 import pickle
 import re
 from transformers import pipeline
+import torch
 
 
 
@@ -22,19 +23,21 @@ with open('articles.pkl', 'rb') as file:
 
 index_loaded = faiss.read_index("sentence_embeddings_index_no_citation.faiss")
 
+
+device = 0 if torch.cuda.is_available() else -1
 if 'model_embedding' not in st.session_state:
-    st.session_state.model_embedding = SentenceTransformer('bkai-foundation-models/vietnamese-bi-encoder')
+    st.session_state.model_embedding = SentenceTransformer('bkai-foundation-models/vietnamese-bi-encoder', device = device)
 
 
 
 # Replace this with your own checkpoint
 model_checkpoint = "model"
-question_answerer = pipeline("question-answering", model=model_checkpoint)
+question_answerer = pipeline("question-answering", model=model_checkpoint, device = device)
 def question_answering(question):
     print(question)
     query_sentence = [question]
     query_embedding = st.session_state.model_embedding.encode(query_sentence)
-    k = 10
+    k = 20
     D, I = index_loaded.search(query_embedding.astype('float32'), k)  # D is distances, I is indices
     answer = [question_answerer(question=query_sentence[0], context=articles[I[0][i]], max_answer_len = 256) for i in range(k)]
     best_answer = max(answer, key=lambda x: x['score'])
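
For context, the commit makes two changes: both models are placed on a GPU when one is available (a shared device index for the SentenceTransformer encoder and the question-answering pipeline), and the retrieval depth is raised from k = 10 to k = 20. The following is a minimal, self-contained sketch of the same pattern, not the app itself; the checkpoint folder "model", the FAISS index file, and the articles list are assumed to exist as in app.py, and the helper name answer_question is mine. One hedged difference: transformers' pipeline takes an integer (-1 for CPU), while sentence-transformers expects a device string or torch.device, so the sketch passes 'cuda'/'cpu' to the encoder rather than reusing the raw integer.

import faiss
import torch
from sentence_transformers import SentenceTransformer
from transformers import pipeline

# transformers pipeline convention: GPU index, or -1 for CPU.
pipeline_device = 0 if torch.cuda.is_available() else -1
# sentence-transformers convention: device string / torch.device.
encoder_device = "cuda" if torch.cuda.is_available() else "cpu"

embedder = SentenceTransformer(
    "bkai-foundation-models/vietnamese-bi-encoder", device=encoder_device
)
question_answerer = pipeline(
    "question-answering", model="model", device=pipeline_device
)
index = faiss.read_index("sentence_embeddings_index_no_citation.faiss")

def answer_question(question, articles, k=20):
    # Embed the query and retrieve the k nearest article indices from FAISS.
    query_embedding = embedder.encode([question]).astype("float32")
    _, indices = index.search(query_embedding, k)
    # Run extractive QA over each retrieved article, then keep the
    # highest-scoring span (mirrors question_answering in app.py).
    candidates = [
        question_answerer(
            question=question, context=articles[i], max_answer_len=256
        )
        for i in indices[0]
    ]
    return max(candidates, key=lambda c: c["score"])

Since each retrieved article triggers a separate QA call, raising k from 10 to 20 roughly doubles per-question inference work, which is presumably why the models are moved to GPU in the same commit.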