import logging
import os
from time import time

import dotenv
import streamlit as st

# Redirect all Hugging Face caches to /tmp so the app also runs on hosts
# with a read-only application filesystem (e.g. Streamlit Cloud, HF Spaces).
os.environ["HF_HOME"] = "/tmp/.cache/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/.cache/huggingface"
os.environ["HUGGINGFACE_HUB_CACHE"] = "/tmp/.cache/huggingface"

# USER_AGENT must be set before the langchain_community imports below,
# which read it at import time.
os.environ["USER_AGENT"] = "myagent"

os.makedirs("/tmp/.cache/huggingface", exist_ok=True)
os.makedirs("/tmp/chroma_persistent_db", exist_ok=True)
os.makedirs("/tmp/source_files", exist_ok=True)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
from langchain_community.document_loaders.text import TextLoader
from langchain_community.document_loaders import (
    WebBaseLoader,
    PyPDFLoader,
    Docx2txtLoader,
)
from langchain_community.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain

dotenv.load_dotenv()

# Maximum number of sources (files + URLs) allowed in the vector DB.
DB_DOCS_LIMIT = 10


def clean_temp_files():
    """Clean up temporary files to prevent storage issues."""
    try:
        for folder in ["/tmp/source_files"]:
            for filename in os.listdir(folder):
                file_path = os.path.join(folder, filename)
                if os.path.isfile(file_path):
                    os.unlink(file_path)
    except Exception as e:
        logger.warning(f"Error cleaning temp files: {e}")


def stream_llm_response(llm_stream, messages):
    """Yield chunks from the LLM stream and store the full reply in session state."""
    response_message = ""
    for chunk in llm_stream.stream(messages):
        response_message += chunk.content
        yield chunk
    st.session_state.messages.append({"role": "assistant", "content": response_message})
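
# Usage sketch (hypothetical caller; assumes a chat model with a .stream()
# method and an existing st.session_state.messages list):
#
#     st.write_stream(stream_llm_response(llm, st.session_state.messages))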


def load_doc_to_db():
    """Ingest uploaded files (PDF, DOCX, TXT, MD) into the vector database."""
    if "rag_docs" in st.session_state and st.session_state.rag_docs:
        docs = []
        for doc_file in st.session_state.rag_docs:
            if doc_file.name not in st.session_state.rag_sources:
                if len(st.session_state.rag_sources) < DB_DOCS_LIMIT:
                    try:
                        file_path = f"/tmp/source_files/{doc_file.name}"
                        with open(file_path, "wb") as file:
                            file.write(doc_file.getbuffer())

                        if doc_file.type == "application/pdf":
                            loader = PyPDFLoader(file_path)
                        elif doc_file.name.endswith(".docx"):
                            loader = Docx2txtLoader(file_path)
                        elif doc_file.type in ["text/plain", "text/markdown"]:
                            loader = TextLoader(file_path)
                        else:
                            st.warning(f"Unsupported document type: {doc_file.type}")
                            continue

                        docs.extend(loader.load())
                        st.session_state.rag_sources.append(doc_file.name)
                        logger.info(f"Successfully loaded document: {doc_file.name}")
                    except Exception as e:
                        st.toast(f"Error loading document {doc_file.name}: {str(e)}", icon="⚠️")
                        logger.error(f"Error loading document: {e}")
                    finally:
                        # Always remove the temp copy once the loader has read it.
                        if os.path.exists(file_path):
                            os.remove(file_path)
                else:
                    st.error(f"Max documents reached ({DB_DOCS_LIMIT}).")
        if docs:
            _split_and_load_docs(docs)
            st.toast("Documents loaded successfully.", icon="✅")
            clean_temp_files()
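
# Typical wiring (sketch; the "rag_docs" widget key matches the session-state
# key read above):
#
#     st.file_uploader(
#         "Upload documents",
#         type=["pdf", "docx", "txt", "md"],
#         accept_multiple_files=True,
#         key="rag_docs",
#         on_change=load_doc_to_db,
#     )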


def load_url_to_db():
    """Fetch a web page and ingest its contents into the vector database."""
    if "rag_url" in st.session_state and st.session_state.rag_url:
        url = st.session_state.rag_url
        docs = []
        if url not in st.session_state.rag_sources:
            if len(st.session_state.rag_sources) < DB_DOCS_LIMIT:
                try:
                    loader = WebBaseLoader(url)
                    docs.extend(loader.load())
                    st.session_state.rag_sources.append(url)
                    logger.info(f"Successfully loaded URL: {url}")
                except Exception as e:
                    st.error(f"Error loading from URL {url}: {str(e)}")
                    logger.error(f"Error loading URL: {e}")

                if docs:
                    _split_and_load_docs(docs)
                    st.toast(f"Loaded content from URL: {url}", icon="✅")
            else:
                st.error(f"Max documents reached ({DB_DOCS_LIMIT}).")
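
# Typical wiring (sketch): a text input keyed to "rag_url" that triggers
# ingestion when the user presses Enter:
#
#     st.text_input("Add a URL", key="rag_url", on_change=load_url_to_db)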


def initialize_vector_db(docs):
    """Create a persistent Chroma store from the given documents."""
    embedding = HuggingFaceEmbeddings(
        model_name="BAAI/bge-large-en-v1.5",
        model_kwargs={"device": "cpu"},
        encode_kwargs={"normalize_embeddings": False},
        cache_folder="/tmp/.cache",
    )

    persist_dir = "/tmp/chroma_persistent_db"
    collection_name = "persistent_collection"

    vector_db = Chroma.from_documents(
        documents=docs,
        embedding=embedding,
        persist_directory=persist_dir,
        collection_name=collection_name,
    )

    # Chroma >= 0.4 persists automatically when a persist_directory is set;
    # the explicit call is kept for compatibility with older versions.
    vector_db.persist()
    logger.info("Vector database initialized and persisted")
    return vector_db
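
# Quick smoke test (sketch; assumes the store has already been populated):
#
#     hits = st.session_state.vector_db.similarity_search("what is this corpus about?", k=3)
#     for doc in hits:
#         logger.info(doc.page_content[:200])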


def _split_and_load_docs(docs):
    """Split documents into overlapping chunks and add them to the vector DB."""
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=200,
    )

    chunks = text_splitter.split_documents(docs)

    if "vector_db" not in st.session_state:
        st.session_state.vector_db = initialize_vector_db(chunks)
    else:
        st.session_state.vector_db.add_documents(chunks)
        st.session_state.vector_db.persist()
        logger.info("Added new documents to existing vector database")
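
# The 1000/200 split means consecutive chunks share 200 characters, so a
# sentence cut at a chunk boundary still appears whole in at least one chunk.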


def _get_context_retriever_chain(vector_db, llm):
    """Build a retriever that rewrites the latest input into a standalone query."""
    retriever = vector_db.as_retriever()
    prompt = ChatPromptTemplate.from_messages([
        MessagesPlaceholder(variable_name="messages"),
        ("user", "{input}"),
        ("user", "Given the above conversation, generate a search query to find relevant information."),
    ])
    return create_history_aware_retriever(llm, retriever, prompt)
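
# create_history_aware_retriever first asks the LLM to condense the chat
# history plus the new input into one standalone search query, then runs that
# query through the retriever, so follow-up questions ("what about the second
# one?") still retrieve the right context.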


def get_conversational_rag_chain(llm):
    retriever_chain = _get_context_retriever_chain(st.session_state.vector_db, llm)