samim2024 committed
Commit 2cfdb3c · verified · 1 Parent(s): 6867b0c

Update model.py

Files changed (1):
model.py +52 -21
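In summary, the update replaces the plain-text TextLoader pipeline with direct PDF ingestion via PyPDF2, swaps the deprecated HuggingFaceHub LLM wrapper for HuggingFaceEndpoint (moving from Mistral-7B-Instruct-v0.1 to v0.2), adds an explicit HUGGINGFACEHUB_API_TOKEN check, and changes ask_question to return the retrieved context chunks alongside the answer rather than a bare string.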
model.py CHANGED
@@ -1,48 +1,79 @@
 import os
 import tempfile
+import PyPDF2
 from langchain.vectorstores import FAISS
 from langchain_huggingface import HuggingFaceEmbeddings
-from langchain_community.document_loaders import TextLoader
 from langchain.text_splitter import CharacterTextSplitter
 from langchain.docstore.document import Document
 from langchain.chains import RetrievalQA
-from langchain_community.llms import HuggingFaceHub
-from langchain.embeddings.base import Embeddings
+from langchain_huggingface import HuggingFaceEndpoint
 
-# Use /tmp for writeable cache
+# Use /tmp for cache
 CACHE_DIR = tempfile.gettempdir()
 os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
 os.environ["HF_HOME"] = CACHE_DIR
 
 DATA_PATH = "/app/data"
 VECTORSTORE_PATH = "/app/vectorstore"
-DOCS_FILENAME = "context.txt"
 EMBEDDING_MODEL_NAME = "sentence-transformers/paraphrase-MiniLM-L6-v2"
 
-def load_embedding_model() -> Embeddings:
+def load_embedding_model():
+    """Load sentence transformer embeddings."""
     return HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
 
-def load_documents() -> list[Document]:
-    loader = TextLoader(os.path.join(DATA_PATH, DOCS_FILENAME))
-    raw_docs = loader.load()
-    splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
-    return splitter.split_documents(raw_docs)
+def load_documents(pdf_path):
+    """Extract text from PDF and split into documents."""
+    try:
+        with open(pdf_path, "rb") as f:
+            pdf = PyPDF2.PdfReader(f)
+            text = "".join(page.extract_text() or "" for page in pdf.pages)
+        if not text.strip():
+            raise ValueError("No text extracted from PDF")
+
+        splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
+        docs = splitter.create_documents([text])
+        return docs
+    except Exception as e:
+        raise ValueError(f"Failed to process PDF: {str(e)}")
 
-def load_vectorstore() -> FAISS:
+def load_vectorstore(pdf_path):
+    """Load or create FAISS vector store from PDF."""
     vectorstore_file = os.path.join(VECTORSTORE_PATH, "faiss_index")
     embedding_model = load_embedding_model()
+
     if os.path.exists(vectorstore_file):
-        return FAISS.load_local(vectorstore_file, embedding_model, allow_dangerous_deserialization=True)
-    docs = load_documents()
+        try:
+            return FAISS.load_local(vectorstore_file, embedding_model, allow_dangerous_deserialization=True)
+        except:
+            pass  # Rebuild if loading fails
+
+    docs = load_documents(pdf_path)
     vectorstore = FAISS.from_documents(docs, embedding_model)
     vectorstore.save_local(vectorstore_file)
     return vectorstore
 
-def ask_question(query: str) -> str:
-    vectorstore = load_vectorstore()
-    llm = HuggingFaceHub(
-        repo_id="mistralai/Mistral-7B-Instruct-v0.1",
-        model_kwargs={"temperature": 0.5, "max_tokens": 256},
+def ask_question(query, pdf_path):
+    """Run RAG query and return answer with contexts."""
+    api_key = os.getenv("HUGGINGFACEHUB_API_TOKEN")
+    if not api_key:
+        raise ValueError("HUGGINGFACEHUB_API_TOKEN not set")
+
+    vectorstore = load_vectorstore(pdf_path)
+    llm = HuggingFaceEndpoint(
+        repo_id="mistralai/Mistral-7B-Instruct-v0.2",
+        huggingfacehub_api_token=api_key,
+        temperature=0.5,
+        max_new_tokens=256
     )
+
+    qa = RetrievalQA.from_chain_type(
+        llm=llm,
+        retriever=vectorstore.as_retriever(search_kwargs={"k": 3}),
+        return_source_documents=True
+    )
+
+    result = qa({"query": query})
+    return {
+        "answer": result["result"],
+        "contexts": [doc.page_content for doc in result["source_documents"]]
+    }
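For reference, a minimal sketch of how the updated ask_question entry point might be driven from a separate script. The script name, question, and PDF path below are illustrative (not part of this commit), and it assumes model.py is importable and HUGGINGFACEHUB_API_TOKEN is exported:

# query_example.py -- hypothetical caller of the updated model.py
from model import ask_question

# Illustrative inputs; any readable PDF path will do.
result = ask_question(
    "What are the key findings of the document?",  # query
    "/app/data/example.pdf",                       # pdf_path
)

print(result["answer"])
for i, ctx in enumerate(result["contexts"], 1):
    print(f"--- context {i} ---\n{ctx[:200]}")

One caveat worth noting: load_vectorstore caches the FAISS index at /app/vectorstore/faiss_index without keying it to the source PDF, so a later call with a different pdf_path will still answer from the first document's index until that directory is cleared.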