samim2024 committed
Commit e2765f4 · verified · 1 Parent(s): 06acafa

Update model.py

Files changed (1):
  1. model.py +56 -44
model.py CHANGED
@@ -1,63 +1,75 @@
  import os
- from PyPDF2 import PdfReader
- from langchain.text_splitter import RecursiveCharacterTextSplitter
- from langchain_community.embeddings import HuggingFaceEmbeddings
  from langchain_community.vectorstores import FAISS
- from langchain_community.docstore.in_memory import InMemoryDocstore
  from langchain_community.llms import HuggingFaceHub
- from langchain.chains import RetrievalQA
  from langchain.prompts import PromptTemplate
- import uuid
- import faiss
-
- vectorstore = None
-
- def load_vectorstore(pdf_path):
-     global vectorstore
-
-     reader = PdfReader(pdf_path)
-     text = "".join([page.extract_text() or "" for page in reader.pages])
-     splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
-     chunks = splitter.split_text(text)
-
-     embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
-     dim = len(embeddings.embed_query("test"))
-     index = faiss.IndexFlatL2(dim)
-
-     vectorstore = FAISS(
-         embedding_function=embeddings,
-         index=index,
-         docstore=InMemoryDocstore({}),
-         index_to_docstore_id={}
-     )
-     uuids = [str(uuid.uuid4()) for _ in chunks]
-     vectorstore.add_texts(chunks, ids=uuids)
-
- def ask_question(query):
-     global vectorstore
-     if not vectorstore:
-         return "Please upload and index a document first."
-
-     llm = HuggingFaceHub(
          repo_id="mistralai/Mistral-7B-Instruct-v0.1",
-         huggingfacehub_api_token=os.getenv("HUGGINGFACEHUB_API_TOKEN"),
-         model_kwargs={"temperature": 0.7, "max_length": 512}
      )
-
-     retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
-     prompt = PromptTemplate(
-         template="Use the context to answer the question:\nContext: {context}\nQuestion: {question}\nAnswer:",
-         input_variables=["context", "question"]
-     )
-
-     chain = RetrievalQA.from_chain_type(
          llm=llm,
          retriever=retriever,
-         return_source_documents=False,
-         chain_type_kwargs={"prompt": prompt}
      )
-     return chain({"query": query})["result"]
  import os
  from langchain_community.vectorstores import FAISS
+ from langchain_community.embeddings import HuggingFaceEmbeddings
  from langchain_community.llms import HuggingFaceHub
  from langchain.prompts import PromptTemplate
+ from langchain.chains import RetrievalQA
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_community.document_loaders import TextLoader
+ from langchain.docstore.document import Document
+
+ # Load Hugging Face API token from environment
+ HUGGINGFACEHUB_API_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
+
+ # Embedding model (can be changed to any sentence transformer model)
+ embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
+
+ # Prompt template for Mistral
+ prompt_template = PromptTemplate(
+     input_variables=["context", "question"],
+     template="""You are an intelligent assistant. Use the context below to answer the question.
+ If the answer is not contained in the context, say "I don't know."
+
+ Context: {context}
+ Question: {question}
+ Answer:"""
+ )
+
+ def create_vectorstore(doc_path: str = "data/docs.txt"):
+     """Create or load FAISS vectorstore from the given document."""
+     loader = TextLoader(doc_path)
+     documents = loader.load()
+
+     # Split into smaller chunks
+     text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
+     docs = text_splitter.split_documents(documents)
+
+     # Create FAISS vectorstore
+     vectordb = FAISS.from_documents(docs, embedding_model)
+     vectordb.save_local("vectorstore")
+     return vectordb
+
+ def load_vectorstore():
+     """Load existing FAISS vectorstore from disk."""
+     return FAISS.load_local("vectorstore", embedding_model, allow_dangerous_deserialization=True)
+
+ def get_llm():
+     """Load the HuggingFace Mistral LLM."""
+     return HuggingFaceHub(
          repo_id="mistralai/Mistral-7B-Instruct-v0.1",
+         model_kwargs={"temperature": 0.5, "max_new_tokens": 512},
+         huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN
      )
+
+ def build_qa_chain():
+     """Build the full RAG QA chain."""
+     vectordb = load_vectorstore()
+     retriever = vectordb.as_retriever()
+     llm = get_llm()
+
+     qa_chain = RetrievalQA.from_chain_type(
          llm=llm,
          retriever=retriever,
+         return_source_documents=True,
+         chain_type_kwargs={"prompt": prompt_template}
      )
+     return qa_chain
+
+ def ask_question(query: str) -> dict:
+     """Handle a single user query."""
+     chain = build_qa_chain()
+     result = chain({"query": query})
+     return {
+         "answer": result["result"],
+         "sources": [doc.metadata.get("source", "unknown") for doc in result["source_documents"]]
+     }
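
A minimal driver sketch for the refactored module (not part of the commit): it assumes model.py is importable, that a text file exists at data/docs.txt, and that HUGGINGFACEHUB_API_TOKEN is exported in the environment. Note that create_vectorstore must run once before any query, because build_qa_chain reloads the saved index from disk on every call.

# Hypothetical usage of the new model.py; the file path and query are illustrative.
from model import create_vectorstore, ask_question

create_vectorstore("data/docs.txt")                    # build and persist the FAISS index once
result = ask_question("What is this document about?")  # rebuilds the chain from the saved index
print(result["answer"])
print(result["sources"])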