# RAG-Xpert / src/rag_methods.py
# TechyCode — "Update src/rag_methods.py" (commit 13ee6ba, verified)
import functools
import logging
import os
from time import time

import dotenv
import streamlit as st
# Configure environment for Hugging Face Spaces: every Hugging Face cache
# location points at one directory under /tmp.
os.environ["HF_HOME"] = "/tmp/.cache/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/.cache/huggingface"
os.environ["HUGGINGFACE_HUB_CACHE"] = "/tmp/.cache/huggingface"
# LangChain's web loader module (imported below for WebBaseLoader) builds its
# default request headers from USER_AGENT when it is imported, so the value
# must be in the environment before those imports, not after them.
os.environ["USER_AGENT"] = "myagent"
# Create necessary directories: model cache, persistent Chroma store, and the
# staging folder that uploaded files are written to before parsing.
os.makedirs("/tmp/.cache/huggingface", exist_ok=True)
os.makedirs("/tmp/chroma_persistent_db", exist_ok=True)
os.makedirs("/tmp/source_files", exist_ok=True)
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
from langchain_community.document_loaders.text import TextLoader
from langchain_community.document_loaders import (
WebBaseLoader,
PyPDFLoader,
Docx2txtLoader,
)
from langchain_community.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
# Pull variables from a local .env file into the process environment.
dotenv.load_dotenv()
# HTTP user agent used for web page fetches.
os.environ.update(USER_AGENT="myagent")
# Maximum number of sources (uploaded files and URLs share the same
# st.session_state.rag_sources list) that may be indexed.
DB_DOCS_LIMIT = 10
def clean_temp_files(folders=("/tmp/source_files",)):
    """Delete regular files left behind in the temporary upload folder(s).

    Cleanup is best-effort. A folder that cannot be listed, or a file that
    cannot be removed, is logged as a warning and skipped. One failure
    therefore no longer leaves the remaining files on disk.
    Sub-directories are not touched.

    Args:
        folders: A single directory path or an iterable of directory paths
            to clean. Defaults to the upload staging folder.
    """
    # A bare path would otherwise be iterated character by character.
    if isinstance(folders, (str, os.PathLike)):
        folders = (folders,)
    for folder in folders:
        try:
            entries = os.listdir(folder)
        except OSError as e:
            logger.warning(f"Error cleaning temp files: {e}")
            continue
        for filename in entries:
            file_path = os.path.join(folder, filename)
            try:
                if os.path.isfile(file_path):
                    os.unlink(file_path)
            except OSError as e:
                # Keep going: one locked/vanished file must not block the rest.
                logger.warning(f"Error cleaning temp files: {e}")
def stream_llm_response(llm_stream, messages):
    """Re-yield each chunk of ``llm_stream.stream(messages)`` to the caller.

    The ``content`` of every chunk is concatenated as it passes through.
    When the stream is exhausted, that text is stored in
    ``st.session_state.messages`` as an ``assistant`` message.
    """
    full_text = ""
    stream = llm_stream.stream(messages)
    for piece in stream:
        full_text = full_text + piece.content
        yield piece
    assistant_message = {"role": "assistant", "content": full_text}
    st.session_state.messages.append(assistant_message)
def _select_loader(doc_file, file_path):
    """Return a LangChain loader for the saved upload, or None if unsupported.

    PDFs and plain-text/markdown files are matched on the upload's MIME type.
    Word documents are matched on the ``.docx`` file-name extension.
    """
    if doc_file.type == "application/pdf":
        return PyPDFLoader(file_path)
    if doc_file.name.endswith(".docx"):
        return Docx2txtLoader(file_path)
    if doc_file.type in ["text/plain", "text/markdown"]:
        return TextLoader(file_path)
    return None


def load_doc_to_db():
    """Parse newly uploaded files and add their text to the vector DB.

    Steps for each file in ``st.session_state.rag_docs`` whose name is not
    yet in ``st.session_state.rag_sources``:

    1. Write the file to ``/tmp/source_files``.
    2. Parse it with a type-specific loader.
    3. Record its name in ``rag_sources``.
    4. Delete the temporary copy.

    At most ``DB_DOCS_LIMIT`` sources are accepted. Everything parsed in this
    call is then split and indexed in one batch.
    """
    if not ("rag_docs" in st.session_state and st.session_state.rag_docs):
        return
    docs = []
    for doc_file in st.session_state.rag_docs:
        if doc_file.name in st.session_state.rag_sources:
            continue
        if len(st.session_state.rag_sources) >= DB_DOCS_LIMIT:
            # rag_sources only grows inside this loop, so every remaining new
            # file would hit the limit as well; report it once and stop.
            st.error(f"Max documents reached ({DB_DOCS_LIMIT}).")
            break
        # The upload name is client-controlled: keep only its final path
        # component so names like "../x" cannot escape the staging folder.
        file_path = os.path.join(
            "/tmp/source_files", os.path.basename(doc_file.name)
        )
        try:
            with open(file_path, "wb") as file:
                file.write(doc_file.getbuffer())
            # Loaders take the saved path, so select one only after the write.
            loader = _select_loader(doc_file, file_path)
            if loader is None:
                st.warning(f"Unsupported document type: {doc_file.type}")
                continue
            docs.extend(loader.load())
            st.session_state.rag_sources.append(doc_file.name)
            logger.info(f"Successfully loaded document: {doc_file.name}")
        except Exception as e:
            # Best-effort per file: report and move on to the next upload.
            st.toast(f"Error loading document {doc_file.name}: {str(e)}", icon="⚠️")
            logger.error(f"Error loading document: {e}")
        finally:
            # isfile (not exists) so a directory path can never reach os.remove.
            if os.path.isfile(file_path):
                os.remove(file_path)
    if docs:
        _split_and_load_docs(docs)
        st.toast("Documents loaded successfully.", icon="✅")
    clean_temp_files()
def load_url_to_db():
    """Fetch the page at ``st.session_state.rag_url`` and index its text.

    The URL is skipped if it is empty or already listed in
    ``st.session_state.rag_sources``. It is refused once ``DB_DOCS_LIMIT``
    sources exist. Otherwise it is loaded with ``WebBaseLoader``, recorded as
    a source, and its documents are split into the vector DB.
    """
    if not ("rag_url" in st.session_state and st.session_state.rag_url):
        return
    # Pasted URLs often carry stray spaces/newlines. Normalise them so the
    # same page is not fetched and recorded twice under different keys.
    url = st.session_state.rag_url.strip()
    if not url or url in st.session_state.rag_sources:
        return
    if len(st.session_state.rag_sources) >= DB_DOCS_LIMIT:
        st.error(f"Max documents reached ({DB_DOCS_LIMIT}).")
        return
    try:
        docs = WebBaseLoader(url).load()
    except Exception as e:
        st.error(f"Error loading from URL {url}: {str(e)}")
        logger.error(f"Error loading URL: {e}")
        return
    if not docs:
        # Nothing to index: do not spend one of the DB_DOCS_LIMIT slots on it.
        st.warning(f"No content could be loaded from URL: {url}")
        return
    st.session_state.rag_sources.append(url)
    logger.info(f"Successfully loaded URL: {url}")
    _split_and_load_docs(docs)
    st.toast(f"Loaded content from URL: {url}", icon="✅")
@functools.lru_cache(maxsize=1)
def _get_embedding_model():
    """Build the BGE embedding model once per process and reuse it.

    ``initialize_vector_db`` runs for every session that indexes documents.
    Caching shares one model instance across those calls instead of loading
    the large model again each time.
    """
    return HuggingFaceEmbeddings(
        model_name="BAAI/bge-large-en-v1.5",
        model_kwargs={'device': 'cpu'},
        # NOTE(review): BGE models are commonly used with normalized
        # embeddings. Changing this would make new vectors inconsistent with
        # those already persisted, so it is kept as-is.
        encode_kwargs={'normalize_embeddings': False},
        # Differs from the HF_HOME directory configured at module top; kept
        # unchanged so any model already downloaded here is reused.
        cache_folder="/tmp/.cache"
    )


def initialize_vector_db(
    docs,
    *,
    persist_dir="/tmp/chroma_persistent_db",
    collection_name="persistent_collection",
):
    """Create the persistent Chroma store from ``docs`` and return it.

    Args:
        docs: LangChain documents (already chunked) to embed and store.
        persist_dir: Directory Chroma persists to.
        collection_name: Name of the Chroma collection.

    Returns:
        The ``Chroma`` vector store, after ``persist()`` has been called.
    """
    vector_db = Chroma.from_documents(
        documents=docs,
        embedding=_get_embedding_model(),
        persist_directory=persist_dir,
        collection_name=collection_name
    )
    vector_db.persist()
    logger.info("Vector database initialized and persisted")
    return vector_db
def _split_and_load_docs(docs, *, chunk_size=1000, chunk_overlap=200):
    """Split documents into overlapping chunks and index them.

    The first call in a session creates the store via ``initialize_vector_db``
    and keeps it in ``st.session_state.vector_db``. Later calls add chunks to
    that store and persist it.

    Args:
        docs: LangChain documents to index.
        chunk_size: Maximum chunk length passed to the text splitter.
        chunk_overlap: Overlap between consecutive chunks.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    chunks = text_splitter.split_documents(docs)
    if not chunks:
        # Documents without extractable text (e.g. image-only PDFs) produce
        # no chunks. There is nothing to embed, and an empty batch may be
        # rejected by the vector store, so skip indexing.
        logger.warning(f"No text chunks produced from {len(docs)} document(s); nothing indexed")
        return
    if "vector_db" not in st.session_state:
        st.session_state.vector_db = initialize_vector_db(chunks)
    else:
        st.session_state.vector_db.add_documents(chunks)
        st.session_state.vector_db.persist()
        logger.info("Added new documents to existing vector database")
def _get_context_retriever_chain(vector_db, llm):
    """Wrap ``vector_db``'s retriever in a history-aware retriever.

    The prompt gives the LLM the conversation so far (``messages``) and the
    latest ``input``. It then asks the LLM for a search query used against
    the vector store's retriever.
    """
    query_rewrite_messages = [
        MessagesPlaceholder(variable_name="messages"),
        ("user", "{input}"),
        ("user", "Given the above conversation, generate a search query to find relevant information."),
    ]
    query_prompt = ChatPromptTemplate.from_messages(query_rewrite_messages)
    return create_history_aware_retriever(llm, vector_db.as_retriever(), query_prompt)
def get_conversational_rag_chain(llm):
retriever_chain = _get_context_retriever_chain