"""Enigma Chatbot: a Streamlit RAG app.

Loads user-supplied PDF/web documents, chunks them semantically, embeds them
into a local persistent Chroma store, and answers chat queries with a
conversational retrieval chain, showing the supporting sources in the sidebar.
"""

import os
from pathlib import Path

import streamlit as st
from langchain.chains import ConversationalRetrievalChain
from langchain.document_loaders import PyPDFLoader, WebBaseLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms.openai import OpenAI, OpenAIChat
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMChainExtractor
from langchain.text_splitter import RecursiveCharacterTextSplitter  # noqa: F401
from langchain.vectorstores import Chroma
from langchain_experimental.text_splitter import SemanticChunker

# Directory next to this file where the Chroma index is persisted across runs.
LOCAL_VECTOR_STORE_DIR = Path(__file__).resolve().parent.joinpath("vector_store")


def load_documents():
    """Load every URL in ``st.session_state.source_doc_urls``.

    URLs ending in ``.pdf`` go through :class:`PyPDFLoader`; everything else
    through :class:`WebBaseLoader`.

    Returns:
        A flat list of LangChain ``Document`` objects from all sources.
    """
    loaders = [
        PyPDFLoader(source_doc_url)
        if source_doc_url.endswith(".pdf")
        else WebBaseLoader(source_doc_url)
        for source_doc_url in st.session_state.source_doc_urls
    ]
    documents = []
    for loader in loaders:
        documents.extend(loader.load())
    return documents


def split_documents(documents):
    """Split *documents* into semantically coherent chunks.

    Returns:
        A list of chunked ``Document`` objects.
    """
    # BUG FIX: OpenAIEmbeddings takes no `temperature` argument — that is an
    # LLM sampling parameter, not an embeddings one.
    text_splitter = SemanticChunker(OpenAIEmbeddings())
    return text_splitter.split_documents(documents)


def embeddings_on_local_vectordb(texts):
    """Embed *texts* into the persistent local Chroma store.

    Returns:
        A contextual-compression retriever (LLM-extracted snippets over an
        MMR top-3 base retriever).
    """
    vectordb = Chroma.from_documents(
        texts,
        embedding=OpenAIEmbeddings(),  # BUG FIX: dropped invalid `temperature`
        persist_directory=LOCAL_VECTOR_STORE_DIR.as_posix(),
    )
    vectordb.persist()
    retriever = ContextualCompressionRetriever(
        base_compressor=LLMChainExtractor.from_llm(OpenAI(temperature=0)),
        base_retriever=vectordb.as_retriever(search_kwargs={"k": 3}, search_type="mmr"),
    )
    return retriever


def query_llm(retriever, query):
    """Answer *query* with a conversational retrieval chain over *retriever*.

    Appends the (query, answer) pair to ``st.session_state.messages``.

    Returns:
        Tuple of (source documents used for the answer, answer string).
    """
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm=OpenAIChat(temperature=0),
        retriever=retriever,
        return_source_documents=True,
        chain_type="refine",
    )
    result = qa_chain({"question": query, "chat_history": st.session_state.messages})
    # PERF FIX: the chain already retrieves and returns its sources
    # (return_source_documents=True); the original issued a second, redundant
    # retriever call for the same query.
    relevant_docs = result["source_documents"]
    answer = result["answer"]
    st.session_state.messages.append((query, answer))
    return relevant_docs, answer


def input_fields():
    """Collect the OpenAI API key and source-document URLs from the sidebar."""
    # SECURITY FIX: the original hard-coded a live OpenAI API key in source
    # (a leaked secret). Prefer the environment; otherwise ask the user via a
    # masked sidebar field.
    api_key = os.environ.get("OPENAI_API_KEY") or st.sidebar.text_input(
        "OpenAI API Key", type="password"
    )
    if api_key:
        os.environ["OPENAI_API_KEY"] = api_key
    # ROBUSTNESS FIX: drop empty entries (e.g. a trailing comma) so we never
    # construct a loader over an empty URL.
    st.session_state.source_doc_urls = [
        url.strip()
        for url in st.sidebar.text_input("Source Document URLs").split(",")
        if url.strip()
    ]


def process_documents():
    """Load, split and embed the configured URLs; store the retriever in session."""
    try:
        documents = load_documents()
        texts = split_documents(documents)
        st.session_state.retriever = embeddings_on_local_vectordb(texts)
    except Exception as e:
        # Surface the failure in the UI rather than crashing the script run.
        st.error(f"An error occurred: {e}")


def boot():
    """Streamlit entry point: render sidebar, chat history, and the chat loop."""
    st.title("Enigma Chatbot")
    input_fields()
    st.sidebar.button("Submit Documents", on_click=process_documents)
    st.sidebar.write("---")
    st.sidebar.write("References made during the chat will appear here")
    if "messages" not in st.session_state:
        st.session_state.messages = []
    for human_msg, ai_msg in st.session_state.messages:
        st.chat_message("human").write(human_msg)
        st.chat_message("ai").write(ai_msg)
    if query := st.chat_input():
        st.chat_message("human").write(query)
        # ROBUSTNESS FIX: the original raised AttributeError when the user
        # chatted before submitting documents (no `retriever` in session).
        if "retriever" not in st.session_state:
            st.error("Please submit source documents before chatting.")
            return
        references, response = query_llm(st.session_state.retriever, query)
        for doc in references:
            # ROBUSTNESS FIX: WebBaseLoader documents carry no "page" metadata
            # key; .get avoids a KeyError for non-PDF sources.
            page = doc.metadata.get("page", "N/A")
            st.sidebar.info(f"Page {page}\n\n{doc.page_content}")
        st.chat_message("ai").write(response)


if __name__ == "__main__":
    boot()