File size: 1,887 Bytes
982eab4
a3c88c4
d665d4a
a3c88c4
e2765f4
d665d4a
e2765f4
d665d4a
 
a3c88c4
e2765f4
a3c88c4
 
d4c3edd
 
982eab4
d665d4a
 
 
 
982eab4
d665d4a
 
982eab4
a3c88c4
d665d4a
a3c88c4
d665d4a
a3c88c4
982eab4
d665d4a
a3c88c4
d665d4a
a3c88c4
 
 
 
 
d665d4a
982eab4
d665d4a
 
 
a3c88c4
d665d4a
982eab4
a3c88c4
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import os
import tempfile
from langchain.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.document_loaders import TextLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.docstore.document import Document
from langchain.chains import RetrievalQA
from langchain_community.llms import HuggingFaceHub
from langchain.embeddings.base import Embeddings

# Use /tmp for writeable cache
# Redirect Hugging Face model/cache downloads to a directory that is
# guaranteed writable (e.g. inside a read-only container filesystem).
CACHE_DIR = tempfile.gettempdir()
os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
os.environ["HF_HOME"] = CACHE_DIR

# Location of the source documents to index.
DATA_PATH = "/app/data"
# Directory where the persisted FAISS index is stored/loaded.
VECTORSTORE_PATH = "/app/vectorstore"
# Name of the single text file used as the RAG context.
DOCS_FILENAME = "context.txt"
# Sentence-transformers model used to embed both documents and queries.
EMBEDDING_MODEL_NAME = "sentence-transformers/paraphrase-MiniLM-L6-v2"

def load_embedding_model() -> Embeddings:
    """Build the sentence-transformer embedding model used for both indexing and querying."""
    model = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
    return model

def load_documents() -> list[Document]:
    """Read the context file and split it into overlapping chunks for indexing."""
    source_path = os.path.join(DATA_PATH, DOCS_FILENAME)
    documents = TextLoader(source_path).load()
    chunker = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    return chunker.split_documents(documents)

def load_vectorstore() -> FAISS:
    """Return the FAISS index, loading it from disk when present, otherwise building and persisting it."""
    index_path = os.path.join(VECTORSTORE_PATH, "faiss_index")
    embeddings = load_embedding_model()
    if os.path.exists(index_path):
        # NOTE(review): allow_dangerous_deserialization unpickles the stored
        # index; acceptable only because the file is produced locally by
        # save_local below — confirm no untrusted path can write it.
        return FAISS.load_local(index_path, embeddings, allow_dangerous_deserialization=True)
    store = FAISS.from_documents(load_documents(), embeddings)
    store.save_local(index_path)
    return store

def ask_question(query: str) -> str:
    """Answer a question with retrieval-augmented generation over the local index.

    Args:
        query: Natural-language question to answer from the indexed documents.

    Returns:
        The model's answer as a plain string.
    """
    vectorstore = load_vectorstore()
    llm = HuggingFaceHub(
        repo_id="mistralai/Mistral-7B-Instruct-v0.1",
        # Fix: the HF text-generation API expects "max_new_tokens";
        # "max_tokens" is not a recognized parameter, so the 256-token
        # limit was never applied.
        model_kwargs={"temperature": 0.5, "max_new_tokens": 256},
    )
    qa = RetrievalQA.from_chain_type(llm=llm, retriever=vectorstore.as_retriever())
    # invoke() replaces the deprecated Chain.run(); it returns a dict with
    # the answer under "result" — unwrap it so callers still get a str.
    result = qa.invoke({"query": query})
    return result["result"] if isinstance(result, dict) else result