"""Minimal retrieval-augmented QA pipeline: index a text file with FAISS and
answer questions against it with a Hugging Face hosted LLM."""

import os
from typing import List

from langchain.chains import RetrievalQA
from langchain_community.document_loaders import TextLoader
from langchain_community.llms import HuggingFaceHub
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import CharacterTextSplitter


# Route Hugging Face caches to a writable directory before any model is
# downloaded. TRANSFORMERS_CACHE is the legacy variable; HF_HOME is current.
CACHE_DIR = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
os.environ["HF_HOME"] = CACHE_DIR
os.makedirs(CACHE_DIR, exist_ok=True)


DATA_PATH = "/app/data"
VECTORSTORE_PATH = "/app/vectorstore"
DOCS_FILENAME = "context.txt"
VECTORSTORE_INDEX_NAME = "faiss_index"
EMBEDDING_MODEL_NAME = "sentence-transformers/paraphrase-MiniLM-L6-v2"
LLM_REPO_ID = "mistralai/Mistral-7B-Instruct-v0.1"


def load_embedding_model() -> Embeddings:
    """Load Hugging Face sentence transformer embeddings."""
    return HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)


def load_documents() -> List[Document]:
    """Load documents and split them into manageable chunks."""
    loader = TextLoader(os.path.join(DATA_PATH, DOCS_FILENAME))
    documents = loader.load()

    # Overlap chunks slightly so context is not lost at chunk boundaries.
    splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    return splitter.split_documents(documents)


def load_vectorstore() -> FAISS:
    """Load the FAISS vectorstore from disk, or create it from the documents."""
    vectorstore_dir = os.path.join(VECTORSTORE_PATH, VECTORSTORE_INDEX_NAME)
    embedding_model = load_embedding_model()

    if os.path.exists(vectorstore_dir):
        # The index on disk is produced by this script, so opting in to
        # pickle deserialization is safe for this trusted, local file.
        return FAISS.load_local(
            folder_path=vectorstore_dir,
            embeddings=embedding_model,
            allow_dangerous_deserialization=True,
        )

    documents = load_documents()
    vectorstore = FAISS.from_documents(documents, embedding_model)
    vectorstore.save_local(vectorstore_dir)
    return vectorstore


def ask_question(query: str) -> str:
    """Run a question-answering chain with the retriever and language model."""
    vectorstore = load_vectorstore()
    retriever = vectorstore.as_retriever()

    # HuggingFaceHub reads HUGGINGFACEHUB_API_TOKEN from the environment.
    # The Inference API's text-generation parameter is "max_new_tokens",
    # not "max_tokens". (HuggingFaceHub is deprecated in favor of
    # HuggingFaceEndpoint from langchain_huggingface.)
    llm = HuggingFaceHub(
        repo_id=LLM_REPO_ID,
        model_kwargs={"temperature": 0.5, "max_new_tokens": 256},
    )

    qa_chain = RetrievalQA.from_chain_type(llm=llm, retriever=retriever)
    result = qa_chain.invoke({"query": query})
    return result["result"]


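# A minimal usage sketch, assuming /app/data/context.txt exists and
# HUGGINGFACEHUB_API_TOKEN is set in the environment; the sample question
# below is illustrative only.
if __name__ == "__main__":
    print(ask_question("What is this document about?"))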