File size: 4,359 Bytes
3a940c5
 
 
 
 
 
3789e3a
3a940c5
 
0965d00
3a940c5
8ebe27a
 
 
8d196c9
 
 
 
 
 
 
3a940c5
 
 
 
 
3789e3a
 
3a940c5
05712e4
3a940c5
 
 
 
 
3789e3a
3a940c5
 
 
 
efd6f37
3a940c5
ebadfbf
 
6ae05ec
3a940c5
 
8f16701
3a940c5
a0635c6
 
3a940c5
a0635c6
 
 
3a940c5
 
 
 
e3d1d56
3a940c5
a0635c6
 
 
 
 
 
 
3a940c5
a0635c6
 
 
 
 
 
 
 
 
e4915f4
a0635c6
 
 
e4915f4
a0635c6
 
 
 
 
 
 
 
 
 
e4915f4
a0635c6
 
 
 
 
 
 
 
 
e4915f4
a0635c6
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import html
import logging
import pickle
import re

import streamlit as st
import faiss
from sentence_transformers import SentenceTransformer
from transformers import pipeline
import torch


# Streamlit expects set_page_config to run before any other st.* call.
st.set_page_config(page_title = "Vietnamese Legal Question Answering System", page_icon= "🐧", layout="centered", initial_sidebar_state="collapsed")

# Inject the app's custom stylesheet. Read it explicitly as UTF-8 so the
# result does not depend on the platform's default locale encoding.
with open("./static/styles.css", encoding="utf-8") as f:
    st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)

# Centered page title.
st.markdown(
    """
    <h1 style="text-align: center;">Vietnamese Legal Question Answering System</h1>
    """,
    unsafe_allow_html=True,
)

@st.cache_resource
def _load_articles(path="articles.pkl"):
    """Load the pickled article collection used as QA contexts.

    Entries are looked up by FAISS row id and passed to the QA pipeline as
    the context text. Cached with ``st.cache_resource`` so the file is read
    once instead of on every Streamlit rerun.

    NOTE(review): pickle.load can execute arbitrary code embedded in the
    file; only load an articles.pkl produced by a trusted build step.
    """
    with open(path, "rb") as file:
        return pickle.load(file)


@st.cache_resource
def _load_index(path="sentence_embeddings_index_no_citation.faiss"):
    """Read the FAISS index of article embeddings (cached across reruns)."""
    return faiss.read_index(path)


articles = _load_articles()
index_loaded = _load_index()

# Device id in the convention used by transformers.pipeline:
# 0 = first CUDA GPU, -1 = CPU.
device = 0 if torch.cuda.is_available() else -1
if 'model_embedding' not in st.session_state:
    # SentenceTransformer takes a torch device string. "cuda:-1" is not a
    # valid device, so the CPU case must be spelled "cpu" explicitly.
    st.session_state.model_embedding = SentenceTransformer(
        'bkai-foundation-models/vietnamese-bi-encoder',
        device="cuda:0" if device >= 0 else "cpu",
    )



# Replace this with your own checkpoint
model_checkpoint = "model"


@st.cache_resource
def _load_question_answerer(checkpoint, device_id):
    """Build the extractive question-answering pipeline once and reuse it.

    Streamlit re-executes the whole script on every interaction; without
    caching, the model weights would be reloaded for every chat message.

    Args:
        checkpoint: Model name or local path passed to ``pipeline``.
        device_id: transformers device id (-1 for CPU, >= 0 for a GPU).
    """
    return pipeline("question-answering", model=checkpoint, device=device_id)


question_answerer = _load_question_answerer(model_checkpoint, device)
logger = logging.getLogger(__name__)


def question_answering(question, k=1, min_score=0.5):
    """Answer *question* by retrieving articles and running extractive QA.

    The question is embedded with the session's sentence-transformer model.
    The ``k`` nearest articles are then fetched from the FAISS index, and the
    QA pipeline is run against each one. The highest-scoring span wins.

    Args:
        question: The user's question text.
        k: Number of nearest articles to use as candidate contexts.
        min_score: Answers whose score is not strictly above this value are
            prefixed with a "not sure" disclaimer.

    Returns:
        The answer text. It may carry the uncertainty prefix. When the index
        yields no usable article, an apology message is returned instead.
    """
    logger.debug("question: %s", question)
    query_embedding = st.session_state.model_embedding.encode([question])
    _, indices = index_loaded.search(query_embedding.astype('float32'), k)
    # FAISS pads missing neighbours with -1; articles[-1] would silently
    # select the last article, so those slots are skipped.
    candidates = [
        question_answerer(question=question, context=articles[idx], max_answer_len=512)
        for idx in indices[0]
        if idx >= 0
    ]
    if not candidates:
        return "Xin lỗi, tôi không tìm thấy điều luật liên quan đến câu hỏi của bạn."
    best_answer = max(candidates, key=lambda x: x['score'])
    logger.debug("best answer: %s", best_answer)
    if best_answer['score'] > min_score:
        return best_answer['answer']
    return f"Tôi không chắc lắm nhưng có lẽ câu trả lời là: \n{best_answer['answer']}"

# if "messages" not in st.session_state:
#     st.session_state.messages = []

# for message in st.session_state.messages:
#     with st.chat_message(message["role"]):
#         st.markdown(message["content"])


# Characters an answer may end with: every Vietnamese letter (all diacritic
# variants, both cases) plus the digits 0-9. The pattern matches the maximal
# trailing run of anything else (punctuation, whitespace, symbols).
_TRAILING_JUNK = re.compile(r'[^aAàÀảẢáÁạẠăĂằẰẳẲẵẴắẮặẶâÂầẦẩẨẫẪấẤậẬbBcCdDđĐeEèÈẻẺẽẼéÉẹẸêÊềỀểỂễỄếẾệỆfFgGhHiIìÌỉỈĩĨíÍịỊjJkKlLmMnNoOòÒỏỎõÕóÓọỌôÔồỒổỔỗỖốỐộỘơƠờỜởỞỡỠớỚợỢpPqQrRsStTuUùÙủỦũŨúÚụỤưƯừỪửỬữỮứỨựỰvVwWxXyYỳỲỷỶỹỸýÝỵỴzZ0-9]+$')


def clean_answer(s):
    """Strip the trailing run of non-letter, non-digit characters from *s*.

    Only the end of the string is touched; punctuation elsewhere is kept.
    """
    cleaned = _TRAILING_JUNK.sub('', s)
    return cleaned
    
# if prompt := st.chat_input("What is up?"):
#         st.session_state.messages.append({"role": "user", "content": prompt})
#         with st.chat_message("user"):
#             st.markdown(prompt)
#         response = clean_answer(question_answering(prompt))
#         with st.chat_message("assistant"):
#             st.markdown(response)
        
#         st.session_state.messages.append({"role": "assistant", "content": response})

def _render_message(role, content):
    """Render one chat bubble as raw HTML (classes styled by styles.css).

    *content* is HTML-escaped before interpolation. It is either the user's
    free-text prompt or model output, and inserting it unescaped into
    ``unsafe_allow_html`` markup would let a prompt inject arbitrary HTML
    into the page.

    Args:
        role: ``'assistant'`` for model replies; anything else renders as
            a user message.
        content: Plain message text to display.
    """
    if role == 'assistant':
        avatar_class = "assistant-avatar"
        message_class = "assistant-message"
        avatar = './app/static/AI.png'
    else:
        avatar_class = "user-avatar"
        message_class = "user-message"
        avatar = './app/static/human.jpg'
    st.markdown(f"""
    <div class="{message_class}">
        <img src="{avatar}" class="{avatar_class}" />
        <div class="stMarkdown">{html.escape(content)}</div>
    </div>
    """, unsafe_allow_html=True)


# Chat history is kept in session state so it survives Streamlit reruns.
if 'messages' not in st.session_state:
    st.session_state.messages = []

# Replay the conversation so far.
for message in st.session_state.messages:
    _render_message(message['role'], message['content'])

if prompt := st.chat_input(placeholder='Xin chào, tôi có thể giúp được gì cho bạn?'):
    _render_message('user', prompt)
    st.session_state.messages.append({'role': 'user', 'content': prompt})

    respond = clean_answer(question_answering(prompt))

    _render_message('assistant', respond)
    st.session_state.messages.append({'role': 'assistant', 'content': respond})