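"""Streamlit demo: Retrieval-Augmented Generation (RAG) with LlamaIndex.

Flow: enter a Hugging Face API key, upload documents, chunk and index them,
retrieve the best-matching chunks for each query, and synthesize an answer
with a hosted LLM. Launch with `streamlit run <this_file>.py`.
"""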
import os
import shutil
import streamlit as st
from huggingface_hub import login
from llama_index.llms.huggingface import HuggingFaceInferenceAPI
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core import (
    Settings,
    SimpleDirectoryReader,
    VectorStoreIndex,
    get_response_synthesizer,
)
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.retrievers import VectorIndexRetriever
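
# --- Configuration ---
# Documents are split into CHUNK_SIZE-token chunks with CHUNK_OVERLAP overlap.
# The retriever fetches TOP_K candidates; SortedRetrieverQueryEngine (below)
# keeps at most MAX_SELECTED_NODES of them scoring >= SIMILARITY_CUTOFF.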
CHUNK_SIZE = 1024
CHUNK_OVERLAP = 128
TOP_K = 10
SIMILARITY_CUTOFF = 0.6
MAX_SELECTED_NODES = 5
TEMP_FILES_DIR = "./temp_files"
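
# --- Page setup (st.set_page_config must be the first Streamlit call) ---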
st.set_page_config(
    page_title="AIVN - RAG with Llama Index",
    page_icon="./static/aivn_favicon.png",
    layout="wide",
    initial_sidebar_state="expanded",
)
st.image("./static/aivn_logo.png", width=300)
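
# Per-session housekeeping: Streamlit reruns this script on every interaction,
# so run_count marks the first run of a session, when the temp upload
# directory and any cached resources are reset.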
if 'run_count' not in st.session_state:
    st.session_state['run_count'] = 0
st.session_state['run_count'] += 1

if st.session_state['run_count'] == 1:
    if os.path.exists(TEMP_FILES_DIR):
        shutil.rmtree(TEMP_FILES_DIR)
    os.makedirs(TEMP_FILES_DIR, exist_ok=True)
    st.cache_resource.clear()
# st.write(f"The app has run {st.session_state['run_count']} times.")
class SortedRetrieverQueryEngine(RetrieverQueryEngine):
    def retrieve(self, query):
        nodes = self.retriever.retrieve(query)
        # Guard against missing scores before applying the cutoff.
        filtered_nodes = [
            node for node in nodes
            if node.score is not None and node.score >= SIMILARITY_CUTOFF
        ]
        # Best matches first, capped at MAX_SELECTED_NODES.
        sorted_nodes = sorted(filtered_nodes, key=lambda node: node.score, reverse=True)
        return sorted_nodes[:MAX_SELECTED_NODES]
st.title("Retrieval-Augmented Generation (RAG) Demo")
hf_api_key_placeholder = st.empty()
hf_api_key = hf_api_key_placeholder.text_input("Enter your Hugging Face API Key", type="password", placeholder="hf_...", key="hf_api_key")
st.markdown("Don't have an API key? Get one [here](https://huggingface.co/settings/tokens) (**Read Token** is enough)")
if hf_api_key:
    @st.cache_resource
    def load_models(hf_api_key):
        login(token=hf_api_key)
        with st.spinner("Loading models from Hugging Face..."):
            llm = HuggingFaceInferenceAPI(
                model_name="mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_api_key)
            embed_model = HuggingFaceEmbedding(
                model_name="BAAI/bge-small-en-v1.5", token=hf_api_key)
        return llm, embed_model

    llm, embed_model = load_models(hf_api_key)
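
    # --- Document upload and parsing ---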
    uploaded_files = st.file_uploader("Upload documents", accept_multiple_files=True, key="uploaded_files")
    if uploaded_files:
        @st.cache_resource
        def uploading_files(uploaded_files, num_documents):
            with st.spinner("Processing uploaded files..."):
                # Persist each upload to disk so SimpleDirectoryReader can load it.
                for uploaded_file in uploaded_files:
                    file_path = os.path.join(TEMP_FILES_DIR, uploaded_file.name)
                    with open(file_path, "wb") as f:
                        f.write(uploaded_file.getbuffer())
                st.write(f"Uploaded {len(uploaded_files)}/{num_documents} files")
            return SimpleDirectoryReader(TEMP_FILES_DIR).load_data()

        num_documents = len(uploaded_files)
        documents = uploading_files(uploaded_files, num_documents)
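
        # --- Chunking and vector indexing ---
        # Each document is split into sentence-aware chunks, which are embedded
        # with the BGE model and stored in an in-memory vector index.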
        @st.cache_resource
        def indexing(_documents, _embed_model, num_documents):
            # Leading underscores tell Streamlit's cache not to hash these arguments.
            with st.spinner("Indexing documents..."):
                text_splitter = SentenceSplitter(
                    chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
                Settings.text_splitter = text_splitter
                st.write(f"Indexing {num_documents} documents")
                return VectorStoreIndex.from_documents(
                    _documents, transformations=[text_splitter],
                    embed_model=_embed_model, show_progress=True,
                )

        index = indexing(documents, embed_model, num_documents)
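
        # --- Retrieval + synthesis pipeline ---
        # The retriever fetches the TOP_K nearest chunks; the custom engine then
        # filters and truncates them before the LLM writes the final answer.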
        @st.cache_resource
        def create_retriever_and_query_engine(_index, _llm, num_documents):
            retriever = VectorIndexRetriever(index=_index, similarity_top_k=TOP_K)
            response_synthesizer = get_response_synthesizer(llm=_llm)
            st.write(f"Querying over {num_documents} indexed chunks")
            return SortedRetrieverQueryEngine(
                retriever=retriever,
                response_synthesizer=response_synthesizer,
                node_postprocessors=[],
            )

        query_engine = create_retriever_and_query_engine(index, llm, len(index.docstore.docs))
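
        # --- Query UI: show the retrieved chunks alongside the generated answer ---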
        query = st.text_input("Enter your query for RAG", key="query")
        if query:
            with st.spinner("Querying..."):
                response = query_engine.query(query)
            retrieved_nodes = response.source_nodes

            st.markdown("### Retrieved Documents")
            for i, node in enumerate(retrieved_nodes):
                with st.expander(f"Document {i + 1} (Score: {node.score:.4f})"):
                    st.write(node.text)

            st.markdown("### RAG Response:")
            st.write(response.response)
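
# Static footer, injected as raw HTML/CSS.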
st.markdown(
    """
    <style>
    .footer {
        position: fixed;
        bottom: 0;
        left: 0;
        width: 100%;
        background-color: #f1f1f1;
        text-align: center;
        padding: 10px 0;
        font-size: 14px;
        color: #555;
    }
    </style>
    <div class="footer">
        2024 AI VIETNAM | Made by <a href="https://github.com/Koii2k3/Basic-RAG-LlamaIndex" target="_blank">Koii2k3</a>
    </div>
    """,
    unsafe_allow_html=True,
)