import math
import os
import re
from statistics import median

import streamlit as st
from bs4 import BeautifulSoup
from langchain.chains import ConversationalRetrievalChain
from langchain.docstore.document import Document
from langchain.document_loaders import PDFMinerPDFasHTMLLoader, WebBaseLoader
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain_openai import ChatOpenAI
from ragatouille import RAGPretrainedModel
st.set_page_config(layout="wide")
# Set your OpenAI API key here; prefer loading it from the environment or st.secrets
# rather than hard-coding a real key in source control.
os.environ["OPENAI_API_KEY"] = "sk-..."
def deep_strip(text):
    """Collapse runs of whitespace into single spaces and trim the ends."""
    return re.sub(r"\s+", " ", text or "").strip()
def embeddings_on_local_vectordb(texts):
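    """Index the document chunks with ColBERT via RAGatouille and wrap the
    resulting retriever in a MultiQueryRetriever, so each user question is
    expanded into several variants before retrieval."""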
    colbert = RAGPretrainedModel.from_pretrained("colbert-ir/colbertv1.9")
    colbert.index(
        collection=[chunk.page_content for chunk in texts],
        split_documents=False,
        document_metadatas=[chunk.metadata for chunk in texts],
        index_name="vector_store",
    )
    retriever = colbert.as_langchain_retriever(k=5)
    retriever = MultiQueryRetriever.from_llm(
        retriever=retriever, llm=ChatOpenAI(temperature=0)
    )
    return retriever
def query_llm(retriever, query):
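    """Answer `query` with a ConversationalRetrievalChain over `retriever`,
    append the exchange to the chat history in session state, and return the
    documents retrieved for the query along with the answer."""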
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm=ChatOpenAI(model="gpt-4-0125-preview", temperature=0),
        retriever=retriever,
        return_source_documents=True,
        chain_type="stuff",
    )
    relevant_docs = retriever.get_relevant_documents(query)
    result = qa_chain({"question": query, "chat_history": st.session_state.messages})
    result = result["answer"]
    st.session_state.messages.append((query, result))
    return relevant_docs, result
def input_fields():
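    """Render the sidebar text input for comma-separated source document URLs
    and store the parsed list in session state."""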
    st.session_state.source_doc_urls = [
        url.strip()
        for url in st.sidebar.text_input("Source Document URLs").split(",")
        if url.strip()
    ]
def process_documents():
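    """Load every configured URL (PDF or web page), split it into header-scoped
    snippets, build the retriever over those snippets, and record the snippet
    headers for display in the UI."""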
    try:
        snippets = []
        for url in st.session_state.source_doc_urls:
            if url.endswith(".pdf"):
                snippets.extend(process_pdf(url))
            else:
                snippets.extend(process_web(url))
        st.session_state.retriever = embeddings_on_local_vectordb(snippets)
        st.session_state.headers = [
            " ".join(snip.metadata["header"].split()[:10]) for snip in snippets
        ]
    except Exception as e:
        st.error(f"An error occurred: {e}")
def process_pdf(url):
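    """Convert a PDF to HTML with PDFMiner, group its text runs by font size,
    merge them into header/content sections, and return one Document per
    section with the header kept in the metadata."""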
    data = PDFMinerPDFasHTMLLoader(url).load()[0]
    content = BeautifulSoup(data.page_content, "html.parser").find_all("div")
    snippets = get_pdf_snippets(content)
    filtered_snippets = filter_pdf_snippets(snippets, new_line_threshold_ratio=0.4)
    median_font_size = math.ceil(
        median([font_size for _, font_size in filtered_snippets])
    )
    semantic_snippets = get_pdf_semantic_snippets(filtered_snippets, median_font_size)
    document_snippets = [
        Document(
            page_content=deep_strip(snip[1]["header_text"]) + " " + deep_strip(snip[0]),
            metadata={
                "header": deep_strip(snip[1]["header_text"]),
                "source_url": url,
                "source_type": "pdf",
            },
        )
        for snip in semantic_snippets
    ]
    return document_snippets
def get_pdf_snippets(content):
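    """Walk the PDFMiner div elements and merge consecutive runs of text that
    share a font size into (text, font_size) tuples."""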
    current_font_size = None
    current_text = ""
    snippets = []
    for cntnt in content:
        span = cntnt.find("span")
        if not span:
            continue
        style = span.get("style")
        if not style:
            continue
        font_size = re.findall(r"font-size:(\d+)px", style)
        if not font_size:
            continue
        font_size = int(font_size[0])
        if current_font_size is None:
            current_font_size = font_size
        if font_size == current_font_size:
            current_text += cntnt.text
        else:
            snippets.append((current_text, current_font_size))
            current_font_size = font_size
            current_text = cntnt.text
    if current_text:
        snippets.append((current_text, current_font_size))
    return snippets
def filter_pdf_snippets(content_list, new_line_threshold_ratio):
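    """Keep only snippets whose newline-to-character ratio is at or below the
    threshold, discarding fragments that are mostly line breaks rather than
    running text."""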
    filtered_list = []
    for content, font_size in content_list:
        if not content:
            continue
        newline_count = content.count("\n")
        total_chars = len(content)
        ratio = newline_count / total_chars
        if ratio <= new_line_threshold_ratio:
            filtered_list.append((content, font_size))
    return filtered_list
def get_pdf_semantic_snippets(filtered_snippets, median_font_size):
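    """Group the filtered snippets into (content, metadata) sections: text set
    larger than the median font size starts a new header, and everything up to
    the next header is accumulated as that section's content."""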
    semantic_snippets = []
    current_header = None
    current_content = ""
    header_font_size = None
    content_font_sizes = []
    for content, font_size in filtered_snippets:
        if font_size > median_font_size:
            if current_header is not None:
                metadata = {
                    "header_font_size": header_font_size,
                    "content_font_size": (
                        median(content_font_sizes) if content_font_sizes else None
                    ),
                    "header_text": current_header,
                }
                semantic_snippets.append((current_content, metadata))
                current_content = ""
                content_font_sizes = []
            current_header = content
            header_font_size = font_size
        else:
            content_font_sizes.append(font_size)
            if current_content:
                current_content += " " + content
            else:
                current_content = content
    if current_header is not None:
        metadata = {
            "header_font_size": header_font_size,
            "content_font_size": (
                median(content_font_sizes) if content_font_sizes else None
            ),
            "header_text": current_header,
        }
        semantic_snippets.append((current_content, metadata))
    return semantic_snippets
def process_web(url):
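    """Load a web page with WebBaseLoader and wrap the whole page in a single
    Document, using the page title (or the URL when no title is present) as
    the header."""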
    data = WebBaseLoader(url).load()[0]
    document_snippets = [
        Document(
            page_content=deep_strip(data.page_content),
            metadata={
                "header": data.metadata.get("title", url),
                "source_url": url,
                "source_type": "web",
            },
        )
    ]
    return document_snippets
def boot():
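    """Lay out the Streamlit app: sidebar inputs and submit button, the list of
    processed section headers, the chat history, and the chat input handler."""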
    st.title("Xi Chatbot")
    input_fields()
    col1, col2 = st.columns([4, 1])
    st.sidebar.button("Submit Documents", on_click=process_documents)
    if "headers" in st.session_state:
        for header in st.session_state.headers:
            col2.info(header)
    if "messages" not in st.session_state:
        st.session_state.messages = []
    for message in st.session_state.messages:
        col1.chat_message("human").write(message[0])
        col1.chat_message("ai").write(message[1])
    if query := col1.chat_input():
        # Guard against querying before any documents have been processed.
        if "retriever" not in st.session_state:
            st.error("Please submit source documents before asking questions.")
            return
        col1.chat_message("human").write(query)
        references, response = query_llm(st.session_state.retriever, query)
        for snip in references:
            st.sidebar.success(
                f'Section {" ".join(snip.metadata["header"].split()[:10])}'
            )
        col1.chat_message("ai").write(response)
if __name__ == "__main__":
    boot()