import os
import tempfile
# Community/core imports replace the deprecated top-level langchain paths
# (langchain.vectorstores, langchain.docstore, langchain.embeddings.base).
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint
from langchain_community.document_loaders import TextLoader
from langchain_text_splitters import CharacterTextSplitter
from langchain_core.documents import Document
from langchain.chains import RetrievalQA
from langchain_core.embeddings import Embeddings
# Route Hugging Face caches to a writeable tmp dir (useful in read-only containers);
# HF_HOME is the current variable, TRANSFORMERS_CACHE covers older transformers releases.
CACHE_DIR = tempfile.gettempdir()
os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
os.environ["HF_HOME"] = CACHE_DIR
DATA_PATH = "/app/data"
VECTORSTORE_PATH = "/app/vectorstore"
DOCS_FILENAME = "context.txt"
EMBEDDING_MODEL_NAME = "sentence-transformers/paraphrase-MiniLM-L6-v2"


def load_embedding_model() -> Embeddings:
    """Load the sentence-transformers model used to embed documents and queries."""
    return HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)


def load_documents() -> list[Document]:
    """Load the context file and split it into overlapping chunks for retrieval."""
    loader = TextLoader(os.path.join(DATA_PATH, DOCS_FILENAME))
    raw_docs = loader.load()
    splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    return splitter.split_documents(raw_docs)


def load_vectorstore() -> FAISS:
    """Return the FAISS index, building and persisting it on first use."""
    vectorstore_file = os.path.join(VECTORSTORE_PATH, "faiss_index")
    embedding_model = load_embedding_model()
    if os.path.exists(vectorstore_file):
        # The index is our own artifact, so pickle deserialization is acceptable here.
        return FAISS.load_local(
            vectorstore_file, embedding_model, allow_dangerous_deserialization=True
        )
    docs = load_documents()
    vectorstore = FAISS.from_documents(docs, embedding_model)
    vectorstore.save_local(vectorstore_file)
    return vectorstore


def ask_question(query: str) -> str:
    """Answer a question with retrieval-augmented generation over the local index."""
    vectorstore = load_vectorstore()
    # HuggingFaceEndpoint replaces the deprecated HuggingFaceHub wrapper; it reads
    # HUGGINGFACEHUB_API_TOKEN from the environment. The HF text-generation API
    # expects max_new_tokens rather than max_tokens.
    llm = HuggingFaceEndpoint(
        repo_id="mistralai/Mistral-7B-Instruct-v0.1",
        temperature=0.5,
        max_new_tokens=256,
    )
    qa = RetrievalQA.from_chain_type(llm=llm, retriever=vectorstore.as_retriever())
    # Chain.run is deprecated; invoke returns a dict keyed by "result".
    return qa.invoke({"query": query})["result"]
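

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module: assumes
    # HUGGINGFACEHUB_API_TOKEN is set and /app/data/context.txt exists.
    print(ask_question("What is this document about?"))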