samim2024 commited on
Commit
982eab4
·
verified ·
1 Parent(s): 404c584

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +63 -0
model.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from PyPDF2 import PdfReader
3
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
4
+ from langchain_community.embeddings import HuggingFaceEmbeddings
5
+ from langchain_community.vectorstores import FAISS
6
+ from langchain_community.docstore.in_memory import InMemoryDocstore
7
+ from langchain_community.llms import HuggingFaceHub
8
+ from langchain.chains import RetrievalQA
9
+ from langchain.prompts import PromptTemplate
10
+ import uuid
11
+ import faiss
12
+
13
+ vectorstore = None
14
+
15
+ def load_vectorstore(pdf_path):
16
+ global vectorstore
17
+
18
+ reader = PdfReader(pdf_path)
19
+ text = "".join([page.extract_text() or "" for page in reader.pages])
20
+ splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
21
+ chunks = splitter.split_text(text)
22
+
23
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
24
+ dim = len(embeddings.embed_query("test"))
25
+ index = faiss.IndexFlatL2(dim)
26
+
27
+ vectorstore = FAISS(
28
+ embedding_function=embeddings,
29
+ index=index,
30
+ docstore=InMemoryDocstore({}),
31
+ index_to_docstore_id={}
32
+ )
33
+ uuids = [str(uuid.uuid4()) for _ in chunks]
34
+ vectorstore.add_texts(chunks, ids=uuids)
35
+
36
+
37
+ def ask_question(query):
38
+ global vectorstore
39
+ if not vectorstore:
40
+ return "Please upload and index a document first."
41
+
42
+ llm = HuggingFaceHub(
43
+ repo_id="mistralai/Mistral-7B-Instruct-v0.1",
44
+ huggingfacehub_api_token=os.getenv("HUGGINGFACEHUB_API_TOKEN"),
45
+ model_kwargs={"temperature": 0.7, "max_length": 512}
46
+ )
47
+
48
+ retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
49
+ prompt = PromptTemplate(
50
+ template="Use the context to answer the question:
51
+ Context: {context}
52
+ Question: {question}
53
+ Answer:",
54
+ input_variables=["context", "question"]
55
+ )
56
+
57
+ chain = RetrievalQA.from_chain_type(
58
+ llm=llm,
59
+ retriever=retriever,
60
+ return_source_documents=False,
61
+ chain_type_kwargs={"prompt": prompt}
62
+ )
63
+ return chain({"query": query})["result"]