import os

import dotenv
from time import time

import streamlit as st

from langchain_community.document_loaders.text import TextLoader
from langchain_community.document_loaders import (
    WebBaseLoader,
    PyPDFLoader,
    Docx2txtLoader,
)
from langchain_community.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain

dotenv.load_dotenv()

os.environ["USER_AGENT"] = "myagent"
DB_DOCS_LIMIT = 10
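
# The helpers below read and write several st.session_state keys
# ("messages", "rag_sources", "rag_docs", "rag_url", "vector_db", and
# optionally "debug_mode"). They are assumed to be initialized by the
# Streamlit app that imports this module, e.g.:
#
#   if "rag_sources" not in st.session_state:
#       st.session_state.rag_sources = []
#   if "messages" not in st.session_state:
#       st.session_state.messages = []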


def stream_llm_response(llm_stream, messages):
    """Stream a plain (non-RAG) LLM answer and record it in the chat history."""
    response_message = ""
    for chunk in llm_stream.stream(messages):
        response_message += chunk.content
        yield chunk
    st.session_state.messages.append({"role": "assistant", "content": response_message})
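
# A sketch of how the host app might render this stream; the chat model,
# model name, and message conversion below are assumptions, not part of
# this module:
#
#   from langchain_openai import ChatOpenAI
#   from langchain_core.messages import AIMessage, HumanMessage
#
#   llm_stream = ChatOpenAI(model="gpt-4o-mini", temperature=0)
#   lc_messages = [
#       HumanMessage(content=m["content"]) if m["role"] == "user"
#       else AIMessage(content=m["content"])
#       for m in st.session_state.messages
#   ]
#   placeholder, text = st.empty(), ""
#   for chunk in stream_llm_response(llm_stream, lc_messages):
#       text += chunk.content
#       placeholder.markdown(text)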


def load_doc_to_db():
    """Load uploaded files from st.session_state.rag_docs into the vector DB."""
    if "rag_docs" in st.session_state and st.session_state.rag_docs:
        docs = []
        for doc_file in st.session_state.rag_docs:
            if doc_file.name not in st.session_state.rag_sources:
                if len(st.session_state.rag_sources) < DB_DOCS_LIMIT:
                    os.makedirs("source_files", exist_ok=True)
                    file_path = f"./source_files/{doc_file.name}"
                    with open(file_path, "wb") as file:
                        file.write(doc_file.read())
                    try:
                        if doc_file.type == "application/pdf":
                            loader = PyPDFLoader(file_path)
                        elif doc_file.name.endswith(".docx"):
                            loader = Docx2txtLoader(file_path)
                        elif doc_file.type in ["text/plain", "text/markdown"]:
                            loader = TextLoader(file_path)
                        else:
                            st.warning(f"Unsupported document type: {doc_file.type}")
                            continue
                        docs.extend(loader.load())
                        st.session_state.rag_sources.append(doc_file.name)
                    except Exception as e:
                        st.toast(f"Error loading document {doc_file.name}: {e}", icon="⚠️")
                    finally:
                        # The temp copy is only needed while the loader reads it.
                        os.remove(file_path)
                else:
                    st.error(f"Max documents reached ({DB_DOCS_LIMIT}).")
        if docs:
            _split_and_load_docs(docs)
            st.toast("Documents loaded successfully.", icon="✅")


def load_url_to_db():
    """Load the page at st.session_state.rag_url into the vector DB."""
    if "rag_url" in st.session_state and st.session_state.rag_url:
        url = st.session_state.rag_url
        docs = []
        if url not in st.session_state.rag_sources:
            if len(st.session_state.rag_sources) < DB_DOCS_LIMIT:
                try:
                    loader = WebBaseLoader(url)
                    docs.extend(loader.load())
                    st.session_state.rag_sources.append(url)
                except Exception as e:
                    st.error(f"Error loading from URL {url}: {e}")
                if docs:
                    _split_and_load_docs(docs)
                    st.toast(f"Loaded content from URL: {url}", icon="✅")
            else:
                st.error(f"Max documents reached ({DB_DOCS_LIMIT}).")


def initialize_vector_db(docs):
    """Create the persistent Chroma collection from the given documents."""
    embedding = HuggingFaceEmbeddings(
        model_name="BAAI/bge-large-en-v1.5",
        model_kwargs={"device": "cpu"},
        encode_kwargs={"normalize_embeddings": False},
    )

    persist_dir = "./chroma_persistent_db"
    collection_name = "persistent_collection"

    vector_db = Chroma.from_documents(
        documents=docs,
        embedding=embedding,
        persist_directory=persist_dir,
        collection_name=collection_name,
    )

    # Recent Chroma releases persist writes to persist_directory automatically;
    # the explicit call is kept for compatibility with older versions.
    vector_db.persist()

    return vector_db
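
# A sketch of re-opening the persisted collection in a later session, using
# the same settings as above (this helper is an assumption, not part of the
# module):
#
#   def load_vector_db():
#       embedding = HuggingFaceEmbeddings(model_name="BAAI/bge-large-en-v1.5")
#       return Chroma(
#           persist_directory="./chroma_persistent_db",
#           collection_name="persistent_collection",
#           embedding_function=embedding,
#       )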


def _split_and_load_docs(docs):
    """Chunk documents and add them to the session's vector DB, creating it if needed."""
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=200,
    )

    chunks = text_splitter.split_documents(docs)

    if "vector_db" not in st.session_state:
        st.session_state.vector_db = initialize_vector_db(chunks)
    else:
        st.session_state.vector_db.add_documents(chunks)
        st.session_state.vector_db.persist()


def _get_context_retriever_chain(vector_db, llm):
    """Build a retriever that rewrites the latest question using the chat history."""
    retriever = vector_db.as_retriever()
    # create_history_aware_retriever only runs the query-rewriting step when a
    # non-empty "chat_history" input is present, so the placeholder must use
    # that exact variable name.
    prompt = ChatPromptTemplate.from_messages([
        MessagesPlaceholder(variable_name="chat_history"),
        ("user", "{input}"),
        ("user", "Given the above conversation, generate a search query to find relevant information."),
    ])
    return create_history_aware_retriever(llm, retriever, prompt)


def get_conversational_rag_chain(llm):
    """Combine the history-aware retriever with a stuff-documents answer chain."""
    retriever_chain = _get_context_retriever_chain(st.session_state.vector_db, llm)
    prompt = ChatPromptTemplate.from_messages([
        ("system",
         """You are a helpful assistant answering the user's queries using the provided context if available.\n
         {context}"""),
        MessagesPlaceholder(variable_name="chat_history"),
        ("user", "{input}"),
    ])
    stuff_documents_chain = create_stuff_documents_chain(llm, prompt)
    return create_retrieval_chain(retriever_chain, stuff_documents_chain)
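
# For reference, the assembled chain can also be invoked without streaming
# (a sketch; create_retrieval_chain returns a dict with "context" and "answer"):
#
#   result = get_conversational_rag_chain(llm).invoke(
#       {"chat_history": [], "input": "What do the loaded documents cover?"}
#   )
#   print(result["answer"])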


def stream_llm_rag_response(llm_stream, messages):
    """Stream a RAG answer grounded in the vector DB and record it in the chat history."""
    rag_chain = get_conversational_rag_chain(llm_stream)

    input_text = messages[-1].content
    history = messages[:-1]

    # Optional debug view: show the chunks the retriever returns for this query.
    if st.session_state.get("debug_mode"):
        retriever = st.session_state.vector_db.as_retriever()
        retrieved_docs = retriever.invoke(input_text)
        st.markdown("### 🔍 Retrieved Context (Debug Mode)")
        for i, doc in enumerate(retrieved_docs):
            st.markdown(f"**Chunk {i+1}:**\n```\n{doc.page_content.strip()}\n```")

    response_message = "*(RAG Response)*\n"
    response = rag_chain.stream({
        "chat_history": history,
        "input": input_text,
    })

    for chunk in response:
        if "answer" in chunk:
            response_message += chunk["answer"]
            yield chunk["answer"]

    st.session_state.messages.append({"role": "assistant", "content": response_message})
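
# A sketch of how a host Streamlit app might wire these helpers to widgets;
# the labels and layout are assumptions, not part of this module:
#
#   with st.sidebar:
#       st.file_uploader(
#           "Upload documents",
#           type=["pdf", "docx", "txt", "md"],
#           accept_multiple_files=True,
#           key="rag_docs",
#           on_change=load_doc_to_db,
#       )
#       st.text_input("Add a URL", key="rag_url", on_change=load_url_to_db)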