import os

from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceHub
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import TextLoader

# Read the Hugging Face Hub API token from the environment so it never lives in source control.
HUGGINGFACEHUB_API_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")

# Sentence-transformer model used to embed both the document chunks and incoming queries.
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

prompt_template = PromptTemplate(
    input_variables=["context", "question"],
    template="""You are an intelligent assistant. Use the context below to answer the question.
If the answer is not contained in the context, say "I don't know."

Context: {context}
Question: {question}
Answer:""",
)


def create_vectorstore(doc_path: str = "data/docs.txt"):
    """Create a FAISS vectorstore from the given document and persist it to disk."""
    loader = TextLoader(doc_path)
    documents = loader.load()

    # Split into overlapping chunks so retrieved passages stay focused without
    # losing context at chunk boundaries.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
    docs = text_splitter.split_documents(documents)

    vectordb = FAISS.from_documents(docs, embedding_model)
    vectordb.save_local("vectorstore")
    return vectordb


def load_vectorstore():
    """Load existing FAISS vectorstore from disk."""
    # allow_dangerous_deserialization opts in to unpickling the saved index;
    # only enable it for index files you created yourself.
    return FAISS.load_local("vectorstore", embedding_model, allow_dangerous_deserialization=True)
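

# Convenience helper (an addition, not part of the original module): reuse the
# saved index when it exists, otherwise build it from the source document.
# Assumes the "vectorstore" directory is only ever written by
# create_vectorstore() above.
def get_vectorstore(doc_path: str = "data/docs.txt"):
    """Return the FAISS vectorstore, creating it on first use."""
    if os.path.isdir("vectorstore"):
        return load_vectorstore()
    return create_vectorstore(doc_path)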


def get_llm():
    """Load the HuggingFace Mistral LLM."""
    return HuggingFaceHub(
        repo_id="mistralai/Mistral-7B-Instruct-v0.1",
        model_kwargs={"temperature": 0.5, "max_new_tokens": 512},
        huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN,
    )
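

# Note: HuggingFaceHub is deprecated in newer LangChain releases. A rough,
# untested equivalent (an assumption about your installed versions, not part
# of the original module) would be:
#
#     from langchain_community.llms import HuggingFaceEndpoint
#
#     llm = HuggingFaceEndpoint(
#         repo_id="mistralai/Mistral-7B-Instruct-v0.1",
#         temperature=0.5,
#         max_new_tokens=512,
#         huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN,
#     )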


def build_qa_chain():
    """Build the full RAG QA chain."""
    vectordb = load_vectorstore()
    retriever = vectordb.as_retriever()
    llm = get_llm()

    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        retriever=retriever,
        return_source_documents=True,
        chain_type_kwargs={"prompt": prompt_template},
    )
    return qa_chain
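

# Optional helper (an addition, not part of the original module): ask_question()
# below rebuilds the whole chain, including reloading the FAISS index, on every
# call. Caching the chain at module level avoids that cost across queries.
_qa_chain = None


def get_qa_chain():
    """Build the QA chain once and reuse it on subsequent calls."""
    global _qa_chain
    if _qa_chain is None:
        _qa_chain = build_qa_chain()
    return _qa_chain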


def ask_question(query: str) -> dict:
    """Handle a single user query."""
    chain = build_qa_chain()
    result = chain.invoke({"query": query})
    return {
        "answer": result["result"],
        "sources": [doc.metadata.get("source", "unknown") for doc in result["source_documents"]],
    }
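

# Minimal usage sketch (an addition, not part of the original module). Assumes
# data/docs.txt exists and HUGGINGFACEHUB_API_TOKEN is set in the environment;
# the index must be built once before ask_question() can load it.
if __name__ == "__main__":
    if not os.path.isdir("vectorstore"):
        create_vectorstore("data/docs.txt")
    response = ask_question("What is this document about?")
    print("Answer:", response["answer"])
    print("Sources:", response["sources"])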