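"""Streamlit demo: Retrieval-Augmented Generation (RAG) with LlamaIndex.

The user supplies a Hugging Face API key, uploads documents, and asks
questions; retrieval runs over a vector index with a similarity cutoff,
and a hosted instruction-tuned LLM synthesizes the final answer.
"""
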
import os
import shutil
import streamlit as st
from huggingface_hub import login
from llama_index.llms.huggingface import HuggingFaceInferenceAPI
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core import VectorStoreIndex
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core import SimpleDirectoryReader
from llama_index.core import get_response_synthesizer
from llama_index.core import Settings

CHUNK_SIZE = 1024                # tokens per chunk when splitting documents
CHUNK_OVERLAP = 128              # tokens shared between consecutive chunks
TOP_K = 10                       # candidates fetched by the vector retriever
SIMILARITY_CUTOFF = 0.6          # minimum score for a retrieved node to be kept
MAX_SELECTED_NODES = 5           # nodes passed on to the response synthesizer
TEMP_FILES_DIR = "./temp_files"  # scratch directory for uploaded files

st.set_page_config(
    page_title="AIVN - RAG with Llama Index",
    page_icon="./static/aivn_favicon.png",
    layout="wide",
    initial_sidebar_state="expanded"
)

st.image("./static/aivn_logo.png", width=300)

# On the first run of a session, start from a clean slate: wipe any leftover
# uploads and drop cached models/indexes from previous sessions.
if 'run_count' not in st.session_state:
    st.session_state['run_count'] = 0

st.session_state['run_count'] += 1
if st.session_state['run_count'] == 1:
    if os.path.exists(TEMP_FILES_DIR):
        shutil.rmtree(TEMP_FILES_DIR)
    os.makedirs(TEMP_FILES_DIR, exist_ok=True)
    st.cache_resource.clear()

# st.write(f"Ứng dụng đã chạy {st.session_state['run_count']} lần.")


class SortedRetrieverQueryEngine(RetrieverQueryEngine):
    """RetrieverQueryEngine that discards low-scoring nodes and keeps only
    the MAX_SELECTED_NODES highest-scoring ones for synthesis."""

    def retrieve(self, query):
        nodes = self.retriever.retrieve(query)
        # Drop weak matches, then rank the survivors by similarity score.
        filtered_nodes = [node for node in nodes if node.score >= SIMILARITY_CUTOFF]
        sorted_nodes = sorted(filtered_nodes, key=lambda node: node.score, reverse=True)
        return sorted_nodes[:MAX_SELECTED_NODES]

st.title("Retrieval-Augmented Generation (RAG) Demo")

hf_api_key_placeholder = st.empty()
hf_api_key = hf_api_key_placeholder.text_input("Enter your Hugging Face API Key", type="password", placeholder="hf_...", key="hf_api_key")
st.markdown("Don't have an API key? Get one [here](https://huggingface.co/settings/tokens) (**Read Token** is enough)")

if hf_api_key:
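    # Streamlit reruns this script top-to-bottom on every interaction;
    # st.cache_resource keyed on the API key makes login and model
    # construction happen only once per key.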
    @st.cache_resource
    def load_models(hf_api_key):
        login(token=hf_api_key)
        with st.spinner("Loading models from Hugging Face..."):
            llm = HuggingFaceInferenceAPI(
                model_name="mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_api_key)
            embed_model = HuggingFaceEmbedding(model_name='BAAI/bge-small-en-v1.5', token=hf_api_key)
        return llm, embed_model
    
    llm, embed_model = load_models(hf_api_key)

    uploaded_files = st.file_uploader("Upload documents", accept_multiple_files=True, key="uploaded_files")
    if uploaded_files:
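        # Uploads arrive as in-memory buffers; write them to TEMP_FILES_DIR
        # so SimpleDirectoryReader can pick them up from disk.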
        @st.cache_resource
        def uploading_files(uploaded_files, num_documents):
            with st.spinner("Processing uploaded files..."):
                for uploaded_file in uploaded_files:
                    file_path = os.path.join(TEMP_FILES_DIR, uploaded_file.name)
                    with open(file_path, "wb") as f:
                        f.write(uploaded_file.getbuffer())

                st.write(f"Uploaded {len(uploaded_files)}/{num_documents} files")

                return SimpleDirectoryReader(TEMP_FILES_DIR).load_data()
        
        num_documents = len(uploaded_files)
        documents = uploading_files(uploaded_files, num_documents)

        @st.cache_resource
        def indexing(_documents, _embed_model, num_documents):
            # Underscore-prefixed arguments are excluded from Streamlit's
            # cache key; num_documents keys the cache instead.
            st.write(f"Indexing {num_documents} documents")
            with st.spinner("Indexing documents..."):
                text_splitter = SentenceSplitter(
                    chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
                Settings.text_splitter = text_splitter
                return VectorStoreIndex.from_documents(
                    _documents, transformations=[text_splitter],
                    embed_model=_embed_model, show_progress=True)
        
        index = indexing(documents, embed_model, num_documents)

        @st.cache_resource
        def create_retriever_and_query_engine(_index, _llm, num_nodes):
            retriever = VectorIndexRetriever(
                index=_index, similarity_top_k=TOP_K)

            response_synthesizer = get_response_synthesizer(llm=_llm)

            st.write(f"Querying over {num_nodes} nodes")

            return SortedRetrieverQueryEngine(
                retriever=retriever,
                response_synthesizer=response_synthesizer,
                node_postprocessors=[],
            )

        # index.docstore.docs holds the chunked nodes, so this counts nodes,
        # not the original uploaded documents.
        query_engine = create_retriever_and_query_engine(index, llm, len(index.docstore.docs))
        
        query = st.text_input("Enter your query for RAG", key="query")

        if query:
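            # Full RAG pass: the custom engine retrieves, filters, and ranks
            # source nodes, then the LLM composes an answer from them.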
            with st.spinner("Querying..."):
                response = query_engine.query(query)
                retrieved_nodes = response.source_nodes

                st.markdown("### Retrieved Documents")
                for i, node in enumerate(retrieved_nodes):
                    with st.expander(f"Document {i+1} (Score: {node.score:.4f})"):
                        st.write(node.text)

                st.markdown("### RAG Response:")
                st.write(response.response)
    
st.markdown(
    """
    <style>
    .footer {
        position: fixed;
        bottom: 0;
        left: 0;
        width: 100%;
        background-color: #f1f1f1;
        text-align: center;
        padding: 10px 0;
        font-size: 14px;
        color: #555;
    }
    </style>
    <div class="footer">
        2024 AI VIETNAM | Made by <a href="https://github.com/Koii2k3/Basic-RAG-LlamaIndex" target="_blank">Koii2k3</a>
    </div>
    """,
    unsafe_allow_html=True
)
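
# To launch the demo locally (assuming this file is saved as app.py and
# Streamlit is installed):
#   streamlit run app.py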