samim2024 committed
Commit a3c88c4 · verified · 1 Parent(s): 7a3ba0e

Update model.py

Files changed (1)
  1. model.py +17 -39
model.py CHANGED
@@ -1,70 +1,48 @@
 import os
-from typing import List
+import tempfile
 from langchain.vectorstores import FAISS
-from langchain.embeddings.base import Embeddings
+from langchain_huggingface import HuggingFaceEmbeddings
 from langchain_community.document_loaders import TextLoader
 from langchain.text_splitter import CharacterTextSplitter
 from langchain.docstore.document import Document
 from langchain.chains import RetrievalQA
 from langchain_community.llms import HuggingFaceHub
-from langchain_huggingface import HuggingFaceEmbeddings
+from langchain.embeddings.base import Embeddings
 
-# Configure safe cache directories (writable within container)
-CACHE_DIR = "/tmp/huggingface"
+# Use /tmp for writeable cache
+CACHE_DIR = tempfile.gettempdir()
 os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
 os.environ["HF_HOME"] = CACHE_DIR
-os.makedirs(CACHE_DIR, exist_ok=True)
 
-# Constants
 DATA_PATH = "/app/data"
 VECTORSTORE_PATH = "/app/vectorstore"
 DOCS_FILENAME = "context.txt"
-VECTORSTORE_INDEX_NAME = "faiss_index"
 EMBEDDING_MODEL_NAME = "sentence-transformers/paraphrase-MiniLM-L6-v2"
-LLM_REPO_ID = "mistralai/Mistral-7B-Instruct-v0.1"
-
 
 def load_embedding_model() -> Embeddings:
-    """Load Hugging Face sentence transformer embeddings."""
     return HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
 
-
-def load_documents() -> List[Document]:
-    """Load documents and split them into manageable chunks."""
+def load_documents() -> list[Document]:
     loader = TextLoader(os.path.join(DATA_PATH, DOCS_FILENAME))
-    documents = loader.load()
-
+    raw_docs = loader.load()
     splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
-    return splitter.split_documents(documents)
-
+    return splitter.split_documents(raw_docs)
 
 def load_vectorstore() -> FAISS:
-    """Load FAISS vectorstore from disk or create it from documents."""
-    vectorstore_dir = os.path.join(VECTORSTORE_PATH, VECTORSTORE_INDEX_NAME)
+    vectorstore_file = os.path.join(VECTORSTORE_PATH, "faiss_index")
     embedding_model = load_embedding_model()
-
-    if os.path.exists(vectorstore_dir):
-        return FAISS.load_local(
-            folder_path=vectorstore_dir,
-            embeddings=embedding_model,
-            allow_dangerous_deserialization=True,
-        )
-
-    documents = load_documents()
-    vectorstore = FAISS.from_documents(documents, embedding_model)
-    vectorstore.save_local(vectorstore_dir)
+    if os.path.exists(vectorstore_file):
+        return FAISS.load_local(vectorstore_file, embedding_model, allow_dangerous_deserialization=True)
+    docs = load_documents()
+    vectorstore = FAISS.from_documents(docs, embedding_model)
+    vectorstore.save_local(vectorstore_file)
     return vectorstore
 
-
 def ask_question(query: str) -> str:
-    """Run a question-answering chain with the retriever and language model."""
     vectorstore = load_vectorstore()
-    retriever = vectorstore.as_retriever()
-
     llm = HuggingFaceHub(
-        repo_id=LLM_REPO_ID,
+        repo_id="mistralai/Mistral-7B-Instruct-v0.1",
         model_kwargs={"temperature": 0.5, "max_tokens": 256},
     )
-
-    qa_chain = RetrievalQA.from_chain_type(llm=llm, retriever=retriever)
-    return qa_chain.run(query)
+    qa = RetrievalQA.from_chain_type(llm=llm, retriever=vectorstore.as_retriever())
+    return qa.run(query)
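
For context, a minimal sketch of how the updated module might be driven. The app.py filename, the __main__ guard, and the sample query are illustrative assumptions, not part of this commit; the token check reflects that langchain's HuggingFaceHub LLM typically reads its API token from the HUGGINGFACEHUB_API_TOKEN environment variable.

# app.py — hypothetical driver for the updated model.py (illustrative only)
import os

from model import ask_question  # importing model.py also sets the cache env vars

if __name__ == "__main__":
    # HuggingFaceHub usually picks up the API token from the environment; fail early if absent.
    assert os.environ.get("HUGGINGFACEHUB_API_TOKEN"), "set HUGGINGFACEHUB_API_TOKEN first"

    # The first call builds the FAISS index from /app/data/context.txt and saves it
    # under /app/vectorstore/faiss_index; subsequent calls reload it from disk.
    print(ask_question("What does the context document say?"))

Because load_vectorstore() only rebuilds when /app/vectorstore/faiss_index is missing, the corpus is embedded once and reused across calls; deleting that directory forces a rebuild after context.txt changes.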