File size: 2,861 Bytes
3a940c5
 
 
 
 
 
3789e3a
3a940c5
 
 
 
0965d00
3a940c5
8d196c9
 
 
 
 
 
 
3a940c5
 
 
 
 
3789e3a
 
3a940c5
3789e3a
3a940c5
 
 
 
 
3789e3a
3a940c5
 
 
 
3789e3a
3a940c5
e3d1d56
3a940c5
 
 
 
8f16701
3a940c5
 
 
 
 
 
 
 
 
 
 
e3d1d56
3a940c5
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import pickle
import re
import unicodedata

import faiss
import streamlit as st
import torch
from sentence_transformers import SentenceTransformer
from transformers import pipeline




_APP_TITLE = "Vietnamese Legal Question Answering System"

# Browser-tab title/icon and overall layout of the app.
st.set_page_config(
    page_title=_APP_TITLE,
    page_icon="🐧",
    layout="centered",
    initial_sidebar_state="collapsed",
)

# Centred page heading rendered as raw HTML (hence unsafe_allow_html).
st.markdown(
    f'\n    <h1 style="text-align: center;">{_APP_TITLE}</h1>\n    ',
    unsafe_allow_html=True,
)

# Streamlit re-executes this whole script on every user interaction. The heavy
# resources below are wrapped in st.cache_resource so they are loaded once per
# server process instead of once per rerun (i.e. once per chat message).


@st.cache_resource
def _load_articles(path):
    """Load the pickled article texts addressed by FAISS result positions."""
    # NOTE(review): pickle.load can execute arbitrary code; only load a file
    # you produced yourself, never user-supplied data.
    with open(path, 'rb') as file:
        return pickle.load(file)


@st.cache_resource
def _load_index(path):
    """Read the FAISS index of article embeddings stored at *path*."""
    return faiss.read_index(path)


@st.cache_resource
def _load_embedding_model(name, torch_device):
    """Load the sentence-embedding model used to encode questions."""
    return SentenceTransformer(name, device=torch_device)


@st.cache_resource
def _load_question_answerer(checkpoint, pipeline_device):
    """Build the extractive question-answering pipeline from *checkpoint*."""
    return pipeline("question-answering", model=checkpoint, device=pipeline_device)


articles = _load_articles('articles.pkl')

index_loaded = _load_index("sentence_embeddings_index_no_citation.faiss")

# transformers.pipeline takes an int device (0 = first GPU, -1 = CPU), while
# SentenceTransformer needs a torch device name; -1 is not a valid torch device,
# so the two must be expressed separately.
device = 0 if torch.cuda.is_available() else -1
embedding_device = "cuda" if torch.cuda.is_available() else "cpu"
if 'model_embedding' not in st.session_state:
    st.session_state.model_embedding = _load_embedding_model(
        'bkai-foundation-models/vietnamese-bi-encoder', embedding_device
    )

# Replace this with your own checkpoint
model_checkpoint = "model"
question_answerer = _load_question_answerer(model_checkpoint, device)
def question_answering(question, k=20, threshold=0.5):
    """Answer *question* via dense retrieval followed by extractive QA.

    The question is embedded with ``st.session_state.model_embedding``, the
    ``k`` nearest articles are looked up in ``index_loaded``, and
    ``question_answerer`` is run on each retrieved article. The span with the
    highest score wins.

    Args:
        question: The user's question text.
        k: Number of nearest articles to retrieve and read (default 20).
        threshold: Score the best span must exceed to be returned without a
            hedging prefix (default 0.5).

    Returns:
        The best answer span. If its score is not above ``threshold``, the span
        is prefixed with a "not sure" sentence. If retrieval finds no article,
        a fixed "nothing found" message is returned.
    """
    print(question)
    query_embedding = st.session_state.model_embedding.encode([question])
    # search() returns (distances, ids), one row per query; only ids are used.
    _, indices = index_loaded.search(query_embedding.astype('float32'), k)
    # FAISS pads with id -1 when the index holds fewer than k vectors; using
    # -1 directly would silently read the *last* article.
    candidate_ids = [int(idx) for idx in indices[0] if idx >= 0]
    if not candidate_ids:
        return "Tôi không tìm thấy văn bản liên quan để trả lời câu hỏi này."
    answers = [
        question_answerer(question=question, context=articles[idx], max_answer_len=256)
        for idx in candidate_ids
    ]
    best_answer = max(answers, key=lambda ans: ans['score'])
    print(best_answer['answer'])
    if best_answer['score'] > threshold:
        return best_answer['answer']
    return f"Tôi không chắc lắm nhưng có lẽ câu trả lời là: \n{best_answer['answer']}"

# The chat transcript lives in session_state so it survives Streamlit reruns;
# replay every stored turn so earlier messages stay on screen.
if "messages" not in st.session_state:
    st.session_state["messages"] = []

for entry in st.session_state["messages"]:
    role, content = entry["role"], entry["content"]
    with st.chat_message(role):
        st.markdown(content)


# Lower-case Vietnamese alphabet with every tone mark on every vowel
# (a ă â e ê i o ô ơ u ư y), plus the remaining Latin consonants and đ.
_VIETNAMESE_LOWER = (
    "aàảãáạăằẳẵắặâầẩẫấậ"
    "bcdđ"
    "eèẻẽéẹêềểễếệ"
    "fgh"
    "iìỉĩíị"
    "jklmn"
    "oòỏõóọôồổỗốộơờởỡớợ"
    "pqrst"
    "uùủũúụưừửữứự"
    "vwx"
    "yỳỷỹýỵ"
    "z"
)
# Characters clean_answer keeps at the end of a string: the alphabet above in
# both cases, plus ASCII digits.
_ANSWER_KEEP_CHARS = frozenset(
    _VIETNAMESE_LOWER + _VIETNAMESE_LOWER.upper() + "0123456789"
)


def clean_answer(s):
    """Remove trailing punctuation, whitespace and other non-alphanumerics.

    The string is first normalized to NFC. Characters are then dropped from
    the end until the last one is a Vietnamese letter (either case, any tone)
    or an ASCII digit. Interior characters are left untouched.

    Args:
        s: Answer text, e.g. a span returned by the QA model.

    Returns:
        The NFC-normalized text without its trailing non-alphanumeric run; an
        empty string if no character qualifies.
    """
    # In decomposed (NFD) text a final tone mark is a separate combining code
    # point; composing first keeps it attached to its letter instead of
    # stripping it off.
    s = unicodedata.normalize("NFC", s)
    end = len(s)
    while end and s[end - 1] not in _ANSWER_KEEP_CHARS:
        end -= 1
    return s[:end]
    
# One chat turn: echo the user's question, answer it, show and record the reply.
prompt = st.chat_input("What is up?")
if prompt:
    history = st.session_state.messages
    history.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    raw_answer = question_answering(prompt)
    response = clean_answer(raw_answer)
    with st.chat_message("assistant"):
        st.markdown(response)

    history.append({"role": "assistant", "content": response})