# RAG-Xpert/src/rag_methods.py
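"""Helper functions for the RAG-Xpert Streamlit app.

Covers document/URL ingestion, Chroma vector-store management, and streaming
of both plain and RAG-augmented LLM responses.
"""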
import os
import dotenv
from time import time
import streamlit as st
from langchain_community.document_loaders.text import TextLoader
from langchain_community.document_loaders import (
WebBaseLoader,
PyPDFLoader,
Docx2txtLoader,
)
from langchain_community.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
# Load environment variables (e.g. API keys) from a local .env file
dotenv.load_dotenv()
# User-Agent string reported by WebBaseLoader when fetching URLs
os.environ["USER_AGENT"] = "myagent"
# Maximum number of documents/URLs indexed per session
DB_DOCS_LIMIT = 10
# Stream non-RAG LLM response
def stream_llm_response(llm_stream, messages):
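    """Stream a plain (non-RAG) LLM reply chunk by chunk and append the full
    message to the Streamlit chat history."""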
response_message = ""
for chunk in llm_stream.stream(messages):
response_message += chunk.content
yield chunk
st.session_state.messages.append({"role": "assistant", "content": response_message})
# --- Document Loading and Indexing ---
def load_doc_to_db():
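    """Index the files uploaded via st.session_state.rag_docs.

    Skips files that are already indexed, enforces DB_DOCS_LIMIT, and removes
    the temporary copy written to source_files/ once it has been loaded.
    """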
if "rag_docs" in st.session_state and st.session_state.rag_docs:
docs = []
for doc_file in st.session_state.rag_docs:
if doc_file.name not in st.session_state.rag_sources:
if len(st.session_state.rag_sources) < DB_DOCS_LIMIT:
os.makedirs("source_files", exist_ok=True)
file_path = f"./source_files/{doc_file.name}"
with open(file_path, "wb") as file:
file.write(doc_file.read())
try:
if doc_file.type == "application/pdf":
loader = PyPDFLoader(file_path)
elif doc_file.name.endswith(".docx"):
loader = Docx2txtLoader(file_path)
elif doc_file.type in ["text/plain", "text/markdown"]:
loader = TextLoader(file_path)
else:
st.warning(f"Unsupported document type: {doc_file.type}")
continue
docs.extend(loader.load())
st.session_state.rag_sources.append(doc_file.name)
except Exception as e:
st.toast(f"Error loading document {doc_file.name}: {e}", icon="⚠️")
finally:
os.remove(file_path)
else:
st.error(f"Max documents reached ({DB_DOCS_LIMIT}).")
if docs:
_split_and_load_docs(docs)
st.toast(f"Documents loaded successfully.", icon="βœ…")
def load_url_to_db():
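    """Fetch the page referenced by st.session_state.rag_url with WebBaseLoader
    and add its content to the vector DB (deduplicated and capped at
    DB_DOCS_LIMIT)."""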
if "rag_url" in st.session_state and st.session_state.rag_url:
url = st.session_state.rag_url
docs = []
if url not in st.session_state.rag_sources:
if len(st.session_state.rag_sources) < DB_DOCS_LIMIT:
try:
loader = WebBaseLoader(url)
docs.extend(loader.load())
st.session_state.rag_sources.append(url)
except Exception as e:
st.error(f"Error loading from URL {url}: {e}")
if docs:
_split_and_load_docs(docs)
st.toast(f"Loaded content from URL: {url}", icon="βœ…")
else:
st.error(f"Max documents reached ({DB_DOCS_LIMIT}).")
def initialize_vector_db(docs):
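    """Create (or reopen) the persistent Chroma collection and index the given
    documents using BGE embeddings on CPU."""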
# Initialize HuggingFace embeddings
embedding = HuggingFaceEmbeddings(
model_name="BAAI/bge-large-en-v1.5",
model_kwargs={'device': 'cpu'},
encode_kwargs={'normalize_embeddings': False}
)
# Shared persistent directory for long-term storage
persist_dir = "./chroma_persistent_db"
collection_name = "persistent_collection"
# Create the persistent Chroma vector store
vector_db = Chroma.from_documents(
documents=docs,
embedding=embedding,
persist_directory=persist_dir,
collection_name=collection_name
)
    # Flush to disk (newer chromadb releases persist writes automatically, so
    # this call may be redundant or deprecated depending on the version)
    vector_db.persist()
return vector_db
def _split_and_load_docs(docs):
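    """Split documents into overlapping chunks and add them to the session's
    vector DB, creating the DB on first use."""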
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=1000,
chunk_overlap=200,
)
chunks = text_splitter.split_documents(docs)
if "vector_db" not in st.session_state:
st.session_state.vector_db = initialize_vector_db(chunks)
else:
st.session_state.vector_db.add_documents(chunks)
st.session_state.vector_db.persist() # Save changes
# --- RAG Chain ---
def _get_context_retriever_chain(vector_db, llm):
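    """Return a history-aware retriever that first rewrites the latest user
    turn into a standalone search query based on the conversation."""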
retriever = vector_db.as_retriever()
prompt = ChatPromptTemplate.from_messages([
MessagesPlaceholder(variable_name="messages"),
("user", "{input}"),
("user", "Given the above conversation, generate a search query to find relevant information.")
])
return create_history_aware_retriever(llm, retriever, prompt)
def get_conversational_rag_chain(llm):
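    """Combine the history-aware retriever with a stuff-documents chain so the
    retrieved context is injected into the system prompt."""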
retriever_chain = _get_context_retriever_chain(st.session_state.vector_db, llm)
prompt = ChatPromptTemplate.from_messages([
("system",
"""You are a helpful assistant answering the user's queries using the provided context if available.\n
{context}"""),
MessagesPlaceholder(variable_name="messages"),
("user", "{input}")
])
stuff_documents_chain = create_stuff_documents_chain(llm, prompt)
return create_retrieval_chain(retriever_chain, stuff_documents_chain)
# Stream RAG LLM response
def stream_llm_rag_response(llm_stream, messages):
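    """Stream a RAG-augmented reply: retrieve context for the latest user
    message, stream the answer, and append it to the chat history."""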
rag_chain = get_conversational_rag_chain(llm_stream)
# Extract latest user input and prior messages
input_text = messages[-1].content
history = messages[:-1]
    # --- DEBUG: show the retrieved context ---
    # This runs an extra retrieval purely for display; the RAG chain below
    # performs its own retrieval when generating the answer.
    if st.session_state.get("debug_mode"):
        retriever = st.session_state.vector_db.as_retriever()
        retrieved_docs = retriever.invoke(input_text)
        st.markdown("### 🔍 Retrieved Context (Debug Mode)")
for i, doc in enumerate(retrieved_docs):
st.markdown(f"**Chunk {i+1}:**\n```\n{doc.page_content.strip()}\n```")
response_message = "*(RAG Response)*\n"
response = rag_chain.stream({
"messages": history,
"input": input_text
})
for chunk in response:
if 'answer' in chunk:
response_message += chunk['answer']
yield chunk['answer']
st.session_state.messages.append({"role": "assistant", "content": response_message})
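

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module): one way a
# Streamlit entry point could wire these helpers together. The ChatOpenAI
# model, widget labels, and session-state bootstrapping below are assumptions.
#
#   import streamlit as st
#   from langchain_openai import ChatOpenAI
#   from langchain_core.messages import AIMessage, HumanMessage
#   from rag_methods import load_doc_to_db, stream_llm_rag_response
#
#   st.session_state.setdefault("rag_sources", [])
#   st.session_state.setdefault("messages", [])
#   st.file_uploader("Upload documents", accept_multiple_files=True,
#                    key="rag_docs", on_change=load_doc_to_db)
#
#   if prompt := st.chat_input("Ask about your documents"):
#       st.session_state.messages.append({"role": "user", "content": prompt})
#       llm = ChatOpenAI(model="gpt-4o-mini", streaming=True)
#       history = [HumanMessage(m["content"]) if m["role"] == "user"
#                  else AIMessage(m["content"])
#                  for m in st.session_state.messages]
#       st.write_stream(stream_llm_rag_response(llm, history))
# ---------------------------------------------------------------------------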