File size: 4,663 Bytes
3a940c5
 
 
 
 
 
3789e3a
3a940c5
 
b4f0029
3a940c5
8ebe27a
 
b4f0029
 
 
 
 
 
8d196c9
 
 
 
 
 
 
3a940c5
 
 
 
 
3789e3a
 
3a940c5
05712e4
3a940c5
 
 
 
 
3789e3a
3a940c5
 
 
 
f1ec914
3a940c5
ebadfbf
 
6ae05ec
9d0e5e0
3a940c5
8514b8b
9d0e5e0
 
3a940c5
a0635c6
 
3a940c5
a0635c6
 
 
3a940c5
 
 
 
e3d1d56
3a940c5
a0635c6
 
 
 
 
 
 
3a940c5
a0635c6
 
 
 
 
 
 
 
 
e4915f4
a0635c6
 
 
e4915f4
a0635c6
 
 
 
 
 
 
 
 
 
e4915f4
a0635c6
 
 
 
 
 
 
 
 
e4915f4
a0635c6
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import html
import pickle
import re

import faiss
import streamlit as st
import torch
from sentence_transformers import SentenceTransformer
from transformers import pipeline


# --- Page chrome: tab metadata, stylesheet, logo and title -------------------
# The page configuration is issued before any other Streamlit element below.
st.set_page_config(
    page_title="Vietnamese Legal Question Answering System",
    page_icon="./app/static/Law.png",
    layout="centered",
    initial_sidebar_state="collapsed",
)

# Read the project stylesheet and inject it into the page as a <style> tag.
with open("./static/styles.css") as stylesheet:
    css_rules = stylesheet.read()
st.markdown(f"<style>{css_rules}</style>", unsafe_allow_html=True)

# Logo banner followed by the centred page heading, both emitted as raw HTML.
logo_html = """
    <div class=logo_area>
        <img src="./app/static/Law.png"/>
    </div>
    """
st.markdown(logo_html, unsafe_allow_html=True)

title_html = """
    <h1 style="text-align: center;">Vietnamese Legal Question Answering System</h1>
    """
st.markdown(title_html, unsafe_allow_html=True)

# Retrieval corpus: `articles` is indexed with the row ids returned by the
# FAISS index loaded below.
# NOTE(review): pickle.load can execute arbitrary code contained in the file;
# only deploy an articles.pkl that comes from a trusted source.
with open('articles.pkl', mode='rb') as articles_file:
    articles = pickle.load(articles_file)

# Pre-built FAISS index that is searched with the query embedding.
index_loaded = faiss.read_index(
    "sentence_embeddings_index_no_citation.faiss"
)


# Integer device index in the form the transformers pipeline expects:
# 0 selects the first GPU, -1 selects the CPU.
device = 0 if torch.cuda.is_available() else -1

# SentenceTransformer takes a torch device string instead. Formatting the CPU
# sentinel as "cuda:-1" is not a valid device, so map it to "cpu" explicitly.
embedding_device = f"cuda:{device}" if device >= 0 else "cpu"

# The embedding model is kept in the session state so that script reruns
# within one browser session reuse the already loaded model.
if 'model_embedding' not in st.session_state:
    st.session_state.model_embedding = SentenceTransformer(
        'bkai-foundation-models/vietnamese-bi-encoder',
        device=embedding_device,
    )


# Replace this with your own checkpoint
model_checkpoint = "model"


@st.cache_resource(show_spinner=False)
def _load_question_answerer(checkpoint, device_index):
    """Build the extractive question-answering pipeline once per process.

    Streamlit re-executes this script on every interaction; caching the
    pipeline avoids reloading the checkpoint on each rerun.

    Args:
        checkpoint: Model name or path passed to the pipeline as ``model``.
        device_index: Device index passed to the pipeline (-1 for CPU).

    Returns:
        The "question-answering" pipeline object.
    """
    return pipeline("question-answering", model=checkpoint, device=device_index)


question_answerer = _load_question_answerer(model_checkpoint, device)
def question_answering(question, top_k=200):
    """Answer a question from the nearest articles in the FAISS index.

    The question is embedded with the session's embedding model, the
    ``top_k`` nearest articles are retrieved, the question-answering
    pipeline is run against each of them, and the highest-scoring span
    is selected.

    Args:
        question: The user's question as a string.
        top_k: Number of nearest articles to retrieve and read (default 200,
            the value that was previously hard-coded).

    Returns:
        The best answer text when its score is above 0.7, the same text
        prefixed with an "I am not sure" notice when the score is above 0.3,
        and an apology message otherwise (including when no article could
        be retrieved).
    """
    fallback = "Xin lỗi tôi không biết câu trả lời cho câu hỏi này, vui lòng hỏi lại câu hỏi khác"

    print(question)
    query_embedding = st.session_state.model_embedding.encode([question])
    # search() returns (distances, indices); only the indices are needed here.
    _distances, neighbours = index_loaded.search(query_embedding.astype('float32'), top_k)

    # FAISS fills missing neighbours with -1 when the index holds fewer than
    # top_k vectors; indexing articles with -1 would silently read the last
    # article, so those placeholder slots are skipped.
    candidates = [
        question_answerer(question=question, context=articles[idx], max_answer_len=512)
        for idx in neighbours[0]
        if idx >= 0
    ]
    if not candidates:
        return fallback

    best_answer = max(candidates, key=lambda candidate: candidate['score'])
    print(best_answer)
    if best_answer['score'] > 0.7:
        return best_answer['answer']
    if best_answer['score'] > 0.3:
        return f"Tôi không chắc lắm nhưng có lẽ câu trả lời là: \n{best_answer['answer']}"
    return fallback

# if "messages" not in st.session_state:
#     st.session_state.messages = []

# for message in st.session_state.messages:
#     with st.chat_message(message["role"]):
#         st.markdown(message["content"])


# Matches the trailing run of characters that are neither Vietnamese/ASCII
# letters nor digits. The whitelist includes ã/Ã so that words such as "xã"
# or "đã" at the end of an answer are not truncated.
_TRAILING_NON_ALNUM_RE = re.compile(
    r'[^aAàÀảẢãÃáÁạẠăĂằẰẳẲẵẴắẮặẶâÂầẦẩẨẫẪấẤậẬbBcCdDđĐeEèÈẻẺẽẼéÉẹẸêÊềỀểỂễỄếẾệỆfFgGhHiIìÌỉỈĩĨíÍịỊjJkKlLmMnNoOòÒỏỎõÕóÓọỌôÔồỒổỔỗỖốỐộỘơƠờỜởỞỡỠớỚợỢpPqQrRsStTuUùÙủỦũŨúÚụỤưƯừỪửỬữỮứỨựỰvVwWxXyYỳỲỷỶỹỸýÝỵỴzZ0-9]+$'
)


def clean_answer(s):
    """Strip trailing special characters from an answer string.

    Args:
        s: The answer text.

    Returns:
        ``s`` without the trailing run of characters that are not
        Vietnamese/ASCII letters or digits (punctuation, whitespace, ...).
        Characters elsewhere in the string are left untouched.
    """
    # Use a regex to remove every special character at the end of the string.
    return _TRAILING_NON_ALNUM_RE.sub('', s)
    
# if prompt := st.chat_input("What is up?"):
#         st.session_state.messages.append({"role": "user", "content": prompt})
#         with st.chat_message("user"):
#             st.markdown(prompt)
#         response = clean_answer(question_answering(prompt))
#         with st.chat_message("assistant"):
#             st.markdown(response)
        
#         st.session_state.messages.append({"role": "assistant", "content": response})

def _render_message(role, content):
    """Render one chat bubble as raw HTML.

    Args:
        role: 'assistant' selects the assistant avatar and CSS classes; any
            other value is rendered as a user message.
        content: Message text. It is HTML-escaped before interpolation
            because it originates from user input or model output and the
            markup is rendered with ``unsafe_allow_html=True``.
    """
    if role == 'assistant':
        avatar_class = "assistant-avatar"
        message_class = "assistant-message"
        avatar = './app/static/AI.png'
    else:
        avatar_class = "user-avatar"
        message_class = "user-message"
        avatar = './app/static/human.jpg'
    st.markdown(f"""
    <div class="{message_class}">
        <img src="{avatar}" class="{avatar_class}" />
        <div class="stMarkdown">{html.escape(content)}</div>
    </div>
    """, unsafe_allow_html=True)


# Conversation history lives in the session state so it survives reruns.
if 'messages' not in st.session_state:
    st.session_state.messages = []

# Replay the stored conversation on every rerun of the script.
for message in st.session_state.messages:
    _render_message(message['role'], message['content'])

if prompt := st.chat_input(placeholder='Xin chào, tôi có thể giúp được gì cho bạn?'):
    # Show and store the raw user prompt; escaping happens at render time.
    _render_message('user', prompt)
    st.session_state.messages.append({'role': 'user', 'content': prompt})

    respond = clean_answer(question_answering(prompt))

    _render_message('assistant', respond)
    st.session_state.messages.append({'role': 'assistant', 'content': respond})