import os
import tempfile

from langchain.chains import RetrievalQA
from langchain_community.document_loaders import TextLoader
from langchain_community.llms import HuggingFaceHub
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import CharacterTextSplitter

# Send Hugging Face model caches to a writable temp directory (useful when
# the container filesystem is read-only).
CACHE_DIR = tempfile.gettempdir()
os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
os.environ["HF_HOME"] = CACHE_DIR

DATA_PATH = "/app/data"
VECTORSTORE_PATH = "/app/vectorstore"
DOCS_FILENAME = "context.txt"
EMBEDDING_MODEL_NAME = "sentence-transformers/paraphrase-MiniLM-L6-v2"


def load_embedding_model() -> Embeddings:
    """Load the sentence-transformers model used to embed documents and queries."""
    return HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)


def load_documents() -> list[Document]:
    """Load the context file and split it into overlapping chunks."""
    loader = TextLoader(os.path.join(DATA_PATH, DOCS_FILENAME))
    raw_docs = loader.load()
    splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    return splitter.split_documents(raw_docs)


def load_vectorstore() -> FAISS:
    """Load the FAISS index from disk, building and caching it on first use."""
    vectorstore_file = os.path.join(VECTORSTORE_PATH, "faiss_index")
    embedding_model = load_embedding_model()
    if os.path.exists(vectorstore_file):
        # Pickle deserialization is only safe here because we wrote this index ourselves.
        return FAISS.load_local(
            vectorstore_file, embedding_model, allow_dangerous_deserialization=True
        )
    docs = load_documents()
    vectorstore = FAISS.from_documents(docs, embedding_model)
    vectorstore.save_local(vectorstore_file)
    return vectorstore


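# A minimal sanity-check sketch for the retrieval step alone, with no LLM
# call. `preview_matches` is a hypothetical helper not in the original module,
# and the k=2 default is an arbitrary example value.
def preview_matches(query: str, k: int = 2) -> None:
    """Print the start of the top-k chunks retrieved for a query."""
    store = load_vectorstore()
    for doc in store.similarity_search(query, k=k):
        print(doc.page_content[:200])

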
def ask_question(query: str) -> str:
    """Answer a question with retrieval-augmented generation over the index."""
    vectorstore = load_vectorstore()
    llm = HuggingFaceHub(
        repo_id="mistralai/Mistral-7B-Instruct-v0.1",
        # The HF text-generation parameter is max_new_tokens, not max_tokens.
        model_kwargs={"temperature": 0.5, "max_new_tokens": 256},
    )
    qa = RetrievalQA.from_chain_type(llm=llm, retriever=vectorstore.as_retriever())
    return qa.run(query)
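

# A sketch of a command-line entry point. It assumes HUGGINGFACEHUB_API_TOKEN
# is set in the environment and that /app/data/context.txt exists; the
# question string below is a placeholder.
if __name__ == "__main__":
    print(ask_question("What does the context document describe?"))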