import os
import shutil

from langchain.document_loaders import PyMuPDFLoader
from langchain.docstore.document import Document
from langchain.vectorstores import Chroma
from langchain.text_splitter import (
    RecursiveCharacterTextSplitter,
    SpacyTextSplitter,
)

def load_pdf_as_docs(pdf_path, loader_module=None, load_kwargs=None):
    """Load and parse PDF file(s) from a single file path or a directory."""
    if pdf_path.endswith(".pdf"):
        # A single PDF file.
        pdf_docs = [pdf_path]
    else:
        # A directory: collect every PDF it contains.
        pdf_docs = [
            os.path.join(pdf_path, f)
            for f in os.listdir(pdf_path)
            if f.endswith(".pdf")
        ]

    if load_kwargs is None:
        load_kwargs = {}
    if loader_module is None:
        loader_module = PyMuPDFLoader  # default loader

    docs = []
    for pdf in pdf_docs:
        loader = loader_module(pdf, **load_kwargs)
        docs.extend(loader.load())

    return docs

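# Example (illustrative sketch, not executed on import): load a directory of
# PDFs or a single file. The paths below are placeholders.
#
#   docs = load_pdf_as_docs("./papers")       # every PDF in ./papers
#   docs = load_pdf_as_docs("./paper.pdf")    # a single file
#   print(f"Loaded {len(docs)} page-level documents")
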
def load_xml_as_docs(xml_path, loader_module=None, load_kwargs=None):
    """Load and parse XML file(s) from a single file path or a directory.

    Note: loader_module and load_kwargs are accepted for API symmetry with
    load_pdf_as_docs but are not used here.
    """
    from bs4 import BeautifulSoup
    from unstructured.cleaners.core import group_broken_paragraphs

    if xml_path.endswith(".xml"):
        xml_docs = [xml_path]
    else:
        xml_docs = [
            os.path.join(xml_path, f)
            for f in os.listdir(xml_path)
            if f.endswith(".xml")
        ]

    docs = []
    for xml_file in xml_docs:
        with open(xml_file) as fp:
            soup = BeautifulSoup(fp, features="xml")
            page_text = soup.find_all(string=True)
            parsed_text = "\n".join(page_text)
            # Re-join paragraphs that were broken across lines by the markup.
            parsed_text_grouped = group_broken_paragraphs(parsed_text)

        # Build a "<first author> et al._<title>" source label from the TEI
        # header; fall back to "unknown" if the metadata is missing.
        try:
            from lxml import etree as ET

            tree = ET.parse(xml_file)
            ns = {"tei": "http://www.tei-c.org/ns/1.0"}

            pers_name_elements = tree.xpath(
                "tei:teiHeader/tei:fileDesc/tei:titleStmt/tei:author/tei:persName",
                namespaces=ns,
            )
            author_info = pers_name_elements[0].text + " et al."

            title_elements = tree.xpath(
                "tei:teiHeader/tei:fileDesc/tei:titleStmt/tei:title",
                namespaces=ns,
            )
            title = title_elements[0].text

            source_info = "_".join([author_info, title])
        except Exception:
            source_info = "unknown"

        docs.append(
            Document(page_content=parsed_text_grouped, metadata={"source": source_info})
        )

    return docs

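# Example (illustrative sketch): parse TEI-encoded XML files (matching the
# tei: namespace used above). The path is a placeholder.
#
#   xml_docs = load_xml_as_docs("./xml_files")
#   print(xml_docs[0].metadata["source"])   # "<author> et al._<title>" or "unknown"
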
def get_doc_chunks(docs, splitter=None):
    """Split docs into chunks."""
    if splitter is None:
        # Default: sentence-aware splitting via spaCy, with chunk size and
        # overlap measured in tiktoken tokens.
        splitter = SpacyTextSplitter.from_tiktoken_encoder(
            chunk_size=512,
            chunk_overlap=128,
        )
    chunks = splitter.split_documents(docs)

    return chunks

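# Example (illustrative sketch): chunk the loaded documents, either with the
# default splitter or a custom one.
#
#   chunks = get_doc_chunks(docs)
#   chunks = get_doc_chunks(docs, splitter=RecursiveCharacterTextSplitter(chunk_size=512))
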
def persist_vectorstore(document_chunks, embeddings, persist_directory="db", overwrite=False):
    """Embed document chunks into a Chroma vectorstore and persist it to disk."""
    if overwrite:
        # Drop any existing index; ignore_errors covers a missing directory.
        shutil.rmtree(persist_directory, ignore_errors=True)
    db = Chroma.from_documents(
        documents=document_chunks,
        embedding=embeddings,
        persist_directory=persist_directory,
    )
    db.persist()

    return db

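# Example (illustrative sketch): any LangChain Embeddings object works here;
# HuggingFaceEmbeddings is shown as one possible backend.
#
#   from langchain.embeddings import HuggingFaceEmbeddings
#   embeddings = HuggingFaceEmbeddings()
#   db = persist_vectorstore(chunks, embeddings, persist_directory="db", overwrite=True)
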
class VectorstoreManager:
    """Thin wrapper around the Chroma vectorstore lifecycle."""

    def __init__(self):
        self.vectorstore_class = Chroma
        self.db = None

    def create_db(self, embeddings):
        """Create an empty vectorstore."""
        db = self.vectorstore_class(embedding_function=embeddings)
        self.db = db
        return db

    def load_db(self, persist_directory, embeddings):
        """Load a local vectorstore from disk."""
        db = self.vectorstore_class(
            persist_directory=persist_directory,
            embedding_function=embeddings,
        )
        self.db = db
        return db

    def create_db_from_documents(self, document_chunks, embeddings, persist_directory="db", overwrite=False):
        """Create a vectorstore from document chunks."""
        if overwrite:
            shutil.rmtree(persist_directory, ignore_errors=True)
        db = self.vectorstore_class.from_documents(
            documents=document_chunks,
            embedding=embeddings,
            persist_directory=persist_directory,
        )
        self.db = db
        return db

    def persist_db(self, persist_directory="db"):
        """Persist the current vectorstore.

        Note: Chroma persists to the directory given at construction time;
        persist_directory is unused and kept for backward compatibility.
        """
        assert self.db is not None, "No vectorstore to persist; create or load one first."
        self.db.persist()

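# Example (illustrative sketch): the manager mirrors persist_vectorstore but
# keeps the handle on self.db for later persisting and reloading.
#
#   manager = VectorstoreManager()
#   db = manager.create_db_from_documents(chunks, embeddings, persist_directory="db")
#   manager.persist_db()
#   db = manager.load_db("db", embeddings)    # reload in a later session
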
class RetrieverManager:
    """Build retrievers (plain, reranking, parent-document) over a vectorstore."""

    def __init__(self, vectorstore, k=10):
        self.vectorstore = vectorstore
        self.retriever = vectorstore.as_retriever(search_kwargs={"k": k})

    def get_rerank_retriever(self, base_retriever=None):
        """Wrap a base retriever with a BGE reranker via contextual compression."""
        if base_retriever is None:
            base_retriever = self.retriever

        from rerank import BgeRerank
        from langchain.retrievers import ContextualCompressionRetriever

        compressor = BgeRerank()
        compression_retriever = ContextualCompressionRetriever(
            base_compressor=compressor, base_retriever=base_retriever
        )

        return compression_retriever

    def get_parent_doc_retriever(self, documents, store_file="./store_location"):
        """Index small child chunks but return their larger parent chunks."""
        from langchain.storage.file_system import LocalFileStore
        from langchain.storage import InMemoryStore
        from langchain.storage._lc_store import create_kv_docstore
        from langchain.retrievers import ParentDocumentRetriever

        # Parents live in an in-memory docstore here; for persistence across
        # sessions, use create_kv_docstore(LocalFileStore(store_file)) instead.
        docstore = InMemoryStore()

        # Parents are large context windows; children are the small chunks
        # that actually get embedded and searched.
        parent_splitter = RecursiveCharacterTextSplitter(
            separators=["\n\n", "\n"], chunk_size=1024, chunk_overlap=256
        )
        child_splitter = RecursiveCharacterTextSplitter(
            separators=["\n\n", "\n"], chunk_size=256, chunk_overlap=128
        )

        retriever = ParentDocumentRetriever(
            vectorstore=self.vectorstore,
            docstore=docstore,
            child_splitter=child_splitter,
            parent_splitter=parent_splitter,
            search_kwargs={"k": 10},
        )
        retriever.add_documents(documents)

        return retriever

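# End-to-end sketch (assumptions: a ./papers folder of PDFs, the local rerank
# module providing BgeRerank imported above, and HuggingFaceEmbeddings as a
# stand-in embedding backend; adjust to your setup):
if __name__ == "__main__":
    from langchain.embeddings import HuggingFaceEmbeddings

    embeddings = HuggingFaceEmbeddings()

    docs = load_pdf_as_docs("./papers")               # 1. load
    chunks = get_doc_chunks(docs)                     # 2. split
    db = persist_vectorstore(chunks, embeddings)      # 3. embed + persist

    manager = RetrieverManager(db, k=10)              # 4. retrieve + rerank
    rerank_retriever = manager.get_rerank_retriever()
    results = rerank_retriever.get_relevant_documents("example query")
    print(f"{len(results)} documents after reranking")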
|