TechyCode committed
Commit 13ee6ba · verified · 1 Parent(s): 0bd7570

Update src/rag_methods.py

Files changed (1)
  1. src/rag_methods.py +48 -61
src/rag_methods.py CHANGED

@@ -2,6 +2,21 @@ import os
 import dotenv
 from time import time
 import streamlit as st
+import logging
+
+# Configure environment for Hugging Face Spaces
+os.environ["HF_HOME"] = "/tmp/.cache/huggingface"
+os.environ["TRANSFORMERS_CACHE"] = "/tmp/.cache/huggingface"
+os.environ["HUGGINGFACE_HUB_CACHE"] = "/tmp/.cache/huggingface"
+
+# Create necessary directories
+os.makedirs("/tmp/.cache/huggingface", exist_ok=True)
+os.makedirs("/tmp/chroma_persistent_db", exist_ok=True)
+os.makedirs("/tmp/source_files", exist_ok=True)
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)

 from langchain_community.document_loaders.text import TextLoader
 from langchain_community.document_loaders import (
@@ -21,7 +36,17 @@ dotenv.load_dotenv()
 os.environ["USER_AGENT"] = "myagent"
 DB_DOCS_LIMIT = 10

-# Stream non-RAG LLM response
+def clean_temp_files():
+    """Clean up temporary files to prevent storage issues"""
+    try:
+        for folder in ["/tmp/source_files"]:
+            for filename in os.listdir(folder):
+                file_path = os.path.join(folder, filename)
+                if os.path.isfile(file_path):
+                    os.unlink(file_path)
+    except Exception as e:
+        logger.warning(f"Error cleaning temp files: {e}")
+
 def stream_llm_response(llm_stream, messages):
     response_message = ""
     for chunk in llm_stream.stream(messages):
@@ -29,18 +54,17 @@ def stream_llm_response(llm_stream, messages):
         yield chunk
     st.session_state.messages.append({"role": "assistant", "content": response_message})

-# --- Document Loading and Indexing ---
 def load_doc_to_db():
     if "rag_docs" in st.session_state and st.session_state.rag_docs:
         docs = []
         for doc_file in st.session_state.rag_docs:
             if doc_file.name not in st.session_state.rag_sources:
                 if len(st.session_state.rag_sources) < DB_DOCS_LIMIT:
-                    os.makedirs("source_files", exist_ok=True)
-                    file_path = f"./source_files/{doc_file.name}"
-                    with open(file_path, "wb") as file:
-                        file.write(doc_file.read())
                     try:
+                        file_path = f"/tmp/source_files/{doc_file.name}"
+                        with open(file_path, "wb") as file:
+                            file.write(doc_file.getbuffer())
+
                         if doc_file.type == "application/pdf":
                             loader = PyPDFLoader(file_path)
                         elif doc_file.name.endswith(".docx"):
@@ -50,17 +74,22 @@ def load_doc_to_db():
                         else:
                             st.warning(f"Unsupported document type: {doc_file.type}")
                             continue
+
                         docs.extend(loader.load())
                         st.session_state.rag_sources.append(doc_file.name)
+                        logger.info(f"Successfully loaded document: {doc_file.name}")
                     except Exception as e:
-                        st.toast(f"Error loading document {doc_file.name}: {e}", icon="⚠️")
+                        st.toast(f"Error loading document {doc_file.name}: {str(e)}", icon="⚠️")
+                        logger.error(f"Error loading document: {e}")
                     finally:
-                        os.remove(file_path)
+                        if os.path.exists(file_path):
+                            os.remove(file_path)
                 else:
                     st.error(f"Max documents reached ({DB_DOCS_LIMIT}).")
         if docs:
             _split_and_load_docs(docs)
-            st.toast(f"Documents loaded successfully.", icon="✅")
+            st.toast("Documents loaded successfully.", icon="✅")
+            clean_temp_files()

 def load_url_to_db():
     if "rag_url" in st.session_state and st.session_state.rag_url:
@@ -72,8 +101,10 @@ def load_url_to_db():
                 loader = WebBaseLoader(url)
                 docs.extend(loader.load())
                 st.session_state.rag_sources.append(url)
+                logger.info(f"Successfully loaded URL: {url}")
             except Exception as e:
-                st.error(f"Error loading from URL {url}: {e}")
+                st.error(f"Error loading from URL {url}: {str(e)}")
+                logger.error(f"Error loading URL: {e}")
             if docs:
                 _split_and_load_docs(docs)
                 st.toast(f"Loaded content from URL: {url}", icon="✅")
@@ -81,18 +112,16 @@ def load_url_to_db():
             st.error(f"Max documents reached ({DB_DOCS_LIMIT}).")

 def initialize_vector_db(docs):
-    # Initialize HuggingFace embeddings
     embedding = HuggingFaceEmbeddings(
         model_name="BAAI/bge-large-en-v1.5",
         model_kwargs={'device': 'cpu'},
-        encode_kwargs={'normalize_embeddings': False}
+        encode_kwargs={'normalize_embeddings': False},
+        cache_folder="/tmp/.cache"
     )

-    # Shared persistent directory for long-term storage
-    persist_dir = "./chroma_persistent_db"
+    persist_dir = "/tmp/chroma_persistent_db"
     collection_name = "persistent_collection"

-    # Create the persistent Chroma vector store
     vector_db = Chroma.from_documents(
         documents=docs,
         embedding=embedding,
@@ -100,12 +129,10 @@ def initialize_vector_db(docs):
         collection_name=collection_name
     )

-    # Persist to disk
     vector_db.persist()
-
+    logger.info("Vector database initialized and persisted")
     return vector_db

-
 def _split_and_load_docs(docs):
     text_splitter = RecursiveCharacterTextSplitter(
         chunk_size=1000,
@@ -118,9 +145,8 @@ def _split_and_load_docs(docs):
         st.session_state.vector_db = initialize_vector_db(chunks)
     else:
         st.session_state.vector_db.add_documents(chunks)
-        st.session_state.vector_db.persist() # Save changes
-
-# --- RAG Chain ---
+        st.session_state.vector_db.persist()
+        logger.info("Added new documents to existing vector database")

 def _get_context_retriever_chain(vector_db, llm):
     retriever = vector_db.as_retriever()
@@ -132,43 +158,4 @@ def _get_context_retriever_chain(vector_db, llm):
     return create_history_aware_retriever(llm, retriever, prompt)

 def get_conversational_rag_chain(llm):
-    retriever_chain = _get_context_retriever_chain(st.session_state.vector_db, llm)
-    prompt = ChatPromptTemplate.from_messages([
-        ("system",
-        """You are a helpful assistant answering the user's queries using the provided context if available.\n
-        {context}"""),
-        MessagesPlaceholder(variable_name="messages"),
-        ("user", "{input}")
-    ])
-    stuff_documents_chain = create_stuff_documents_chain(llm, prompt)
-    return create_retrieval_chain(retriever_chain, stuff_documents_chain)
-
-# Stream RAG LLM response
-def stream_llm_rag_response(llm_stream, messages):
-    rag_chain = get_conversational_rag_chain(llm_stream)
-
-    # Extract latest user input and prior messages
-    input_text = messages[-1].content
-    history = messages[:-1]
-
-    # --- DEBUG: Show context retrieved ---
-    if st.session_state.get("debug_mode"):
-        retriever = st.session_state.vector_db.as_retriever()
-        retrieved_docs = retriever.get_relevant_documents(input_text)
-        st.markdown("### 🔍 Retrieved Context (Debug Mode)")
-        for i, doc in enumerate(retrieved_docs):
-            st.markdown(f"**Chunk {i+1}:**\n```\n{doc.page_content.strip()}\n```")
-
-    response_message = "*(RAG Response)*\n"
-    response = rag_chain.stream({
-        "messages": history,
-        "input": input_text
-    })
-
-    for chunk in response:
-        if 'answer' in chunk:
-            response_message += chunk['answer']
-            yield chunk['answer']
-
-    st.session_state.messages.append({"role": "assistant", "content": response_message})
-
+    retriever_chain = _get_context_retriever_chain(st.session_state.vector_db, llm)
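Setting the cache variables at module top, before the langchain and transformers imports, is what makes them effective: the Hugging Face libraries resolve their cache location at import time, and on Spaces the home directory is read-only, so the cache has to live under writable /tmp. A minimal sketch of the ordering constraint (reading huggingface_hub.constants to confirm the resolved path is an assumption about that library's internals, not part of this commit):

import os

# Must be set before huggingface_hub / transformers / sentence-transformers
# are imported; they read these variables at import time.
os.environ["HF_HOME"] = "/tmp/.cache/huggingface"

from huggingface_hub import constants

print(constants.HF_HOME)  # expected: /tmp/.cache/huggingface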
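Note that clean_temp_files() unlinks only regular files and leaves any subdirectories of /tmp/source_files in place, and it downgrades failures to a log warning so a failed cleanup never breaks a request. A hypothetical call site, clearing leftovers once per Streamlit session (the initialized flag and the import path are illustrative assumptions):

import streamlit as st

from rag_methods import clean_temp_files  # assumed import path for this module

# Clear leftovers from a previous session exactly once per session.
if "initialized" not in st.session_state:
    clean_temp_files()
    st.session_state.initialized = True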
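The switch from doc_file.read() to doc_file.getbuffer() is significant: Streamlit's UploadedFile behaves like io.BytesIO, so read() advances a cursor and a second read returns empty bytes (a silent way to write zero-byte files if anything upstream already read the upload), while getbuffer() exposes the full underlying buffer regardless of cursor position. A self-contained illustration with io.BytesIO standing in for the uploaded file:

import io

buf = io.BytesIO(b"%PDF-1.7 ...")  # stand-in for st.file_uploader's UploadedFile

assert buf.read() == b"%PDF-1.7 ..."              # first read consumes the stream
assert buf.read() == b""                          # cursor is now exhausted
assert bytes(buf.getbuffer()) == b"%PDF-1.7 ..."  # full buffer, cursor-independent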
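The loader dispatch in load_doc_to_db covers PDF and .docx in the visible hunks; the module also imports TextLoader, so plain-text files are presumably handled in lines this diff does not touch. A self-contained sketch of that dispatch pattern (Docx2txtLoader and the text/markdown branch are assumptions, not confirmed by this diff):

from langchain_community.document_loaders.text import TextLoader
from langchain_community.document_loaders import Docx2txtLoader, PyPDFLoader

def pick_loader(file_path: str, mime_type: str):
    # Map an uploaded file to a LangChain loader; None means unsupported.
    if mime_type == "application/pdf":
        return PyPDFLoader(file_path)
    if file_path.endswith(".docx"):
        return Docx2txtLoader(file_path)
    if mime_type in ("text/plain", "text/markdown"):
        return TextLoader(file_path)
    return None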
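With the persist directory and collection name now fixed under /tmp, a later session can reopen the store instead of re-embedding everything. A minimal sketch, assuming the langchain_community imports this module uses elsewhere; keep in mind that /tmp is wiped when a Space restarts, so this persistence survives reruns within a running Space rather than rebuilds, and with chromadb 0.4+ writes are persisted automatically, persist() remaining mainly for backward compatibility:

from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma

embedding = HuggingFaceEmbeddings(
    model_name="BAAI/bge-large-en-v1.5",
    model_kwargs={"device": "cpu"},
    cache_folder="/tmp/.cache",
)

# Reopen the existing collection rather than calling Chroma.from_documents.
vector_db = Chroma(
    persist_directory="/tmp/chroma_persistent_db",
    collection_name="persistent_collection",
    embedding_function=embedding,
)

results = vector_db.similarity_search("example query", k=4)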
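The removed stream_llm_rag_response shows how this chain is consumed: create_retrieval_chain streams dict chunks, with the generated text under the "answer" key and chat history passed through the "messages" placeholder. A short sketch in the same spirit, assuming get_conversational_rag_chain is completed to return the retrieval chain as the pre-change version did:

def stream_rag_answer(llm, messages):
    # Hypothetical consumer mirroring the removed stream_llm_rag_response.
    rag_chain = get_conversational_rag_chain(llm)
    history, input_text = messages[:-1], messages[-1].content
    for chunk in rag_chain.stream({"messages": history, "input": input_text}):
        if "answer" in chunk:  # chunks may also carry "context"/"input" keys
            yield chunk["answer"]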