ZarinT commited on
Commit
6812c83
Β·
verified Β·
1 Parent(s): d1fa96d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -24
app.py CHANGED
@@ -18,25 +18,30 @@ from sentence_transformers import CrossEncoder
18
  import google.generativeai as genai
19
  from typing import List
20
  from langchain_core.language_models import BaseLanguageModel
 
21
 
22
  import google.generativeai as genai
23
 
24
 
25
- class GeminiLLM(BaseLanguageModel):
26
  def __init__(self, model_name="models/gemini-1.5-pro-latest", api_key=None):
27
  self.api_key = api_key or st.secrets["GOOGLE_API_KEY"]
28
  if not self.api_key:
29
- raise ValueError("GOOGLE_API_KEY not found in Streamlit secrets.")
30
  genai.configure(api_key=self.api_key)
31
  self.model = genai.GenerativeModel(model_name)
32
-
33
- def _call(self, prompt, stop=None):
34
  response = self.model.generate_content(prompt)
35
  return response.text
36
 
37
  @property
38
- def _llm_type(self):
39
  return "custom_gemini"
 
 
 
 
40
 
41
  class GeminiEmbeddings(Embeddings):
42
  def __init__(self, model_name="models/embedding-001", api_key=None):
@@ -64,19 +69,6 @@ class GeminiEmbeddings(Embeddings):
64
  task_type="retrieval_query"
65
  )["embedding"]
66
 
67
-
68
- class GeminiLLM:
69
- def __init__(self, model_name="models/gemini-1.5-pro-latest", api_key=None):
70
- api_key = api_key or os.getenv("GOOGLE_API_KEY")
71
- if not api_key:
72
- raise ValueError("Missing GOOGLE_API_KEY")
73
- genai.configure(api_key=api_key)
74
- self.model = genai.GenerativeModel(model_name)
75
-
76
- def predict(self, prompt: str) -> str:
77
- response = self.model.generate_content(prompt)
78
- return response.text.strip()
79
-
80
  reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
81
 
82
  vectorstore_global = None
@@ -88,8 +80,8 @@ def load_environment():
88
  def preload_modtran_document():
89
  global vectorstore_global
90
  embeddings = GeminiEmbeddings()
91
- vectorstore = FAISS.load_local("modtran_vectorstore", embeddings, allow_dangerous_deserialization=True)
92
- set_global_vectorstore(vectorstore)
93
  st.session_state.chat_ready = True
94
 
95
  def convert_pdf_to_xml(pdf_file, xml_path):
@@ -183,7 +175,7 @@ def self_reasoning(query, context):
183
 
184
  **Answer:**
185
  """
186
- return llm.predict(reasoning_prompt)
187
 
188
  def faiss_search_with_keywords(query):
189
  global vectorstore_global
@@ -222,7 +214,7 @@ faiss_reasoning_tool = Tool(
222
  )
223
 
224
  def initialize_chatbot_agent():
225
- llm = GeminiLLM() # <-- Gemini instead of OpenAI
226
  memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
227
  tools = [faiss_keyword_tool, faiss_reasoning_tool]
228
  agent = initialize_agent(
@@ -250,12 +242,12 @@ def handle_user_query(query):
250
  def main():
251
  load_environment()
252
 
253
- if "agent" not in st.session_state:
254
- st.session_state.agent = None
255
  if "chat_ready" not in st.session_state:
256
  st.session_state.chat_ready = False
257
  if "chat_history" not in st.session_state:
258
  st.session_state.chat_history = []
 
 
259
 
260
  st.header("Chat with MODTRAN Documents πŸ“„")
261
 
 
18
  import google.generativeai as genai
19
  from typing import List
20
  from langchain_core.language_models import BaseLanguageModel
21
+ from langchain_core.runnables import Runnable
22
 
23
  import google.generativeai as genai
24
 
25
 
26
+ class GeminiLLM(Runnable):
27
  def __init__(self, model_name="models/gemini-1.5-pro-latest", api_key=None):
28
  self.api_key = api_key or st.secrets["GOOGLE_API_KEY"]
29
  if not self.api_key:
30
+ raise ValueError("GOOGLE_API_KEY not found.")
31
  genai.configure(api_key=self.api_key)
32
  self.model = genai.GenerativeModel(model_name)
33
+
34
+ def _call(self, prompt: str, stop=None) -> str:
35
  response = self.model.generate_content(prompt)
36
  return response.text
37
 
38
  @property
39
+ def _llm_type(self) -> str:
40
  return "custom_gemini"
41
+
42
+ def invoke(self, input, config=None):
43
+ response = self.model.generate_content(input)
44
+ return response.text.strip()
45
 
46
  class GeminiEmbeddings(Embeddings):
47
  def __init__(self, model_name="models/embedding-001", api_key=None):
 
69
  task_type="retrieval_query"
70
  )["embedding"]
71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
73
 
74
  vectorstore_global = None
 
80
  def preload_modtran_document():
81
  global vectorstore_global
82
  embeddings = GeminiEmbeddings()
83
+ st.session_state.vectorstore = FAISS.load_local("modtran_vectorstore", embeddings, allow_dangerous_deserialization=True)
84
+ set_global_vectorstore(st.session_state.vectorstore)
85
  st.session_state.chat_ready = True
86
 
87
  def convert_pdf_to_xml(pdf_file, xml_path):
 
175
 
176
  **Answer:**
177
  """
178
+ return llm._call(reasoning_prompt)
179
 
180
  def faiss_search_with_keywords(query):
181
  global vectorstore_global
 
214
  )
215
 
216
  def initialize_chatbot_agent():
217
+ llm = GeminiLLM()
218
  memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
219
  tools = [faiss_keyword_tool, faiss_reasoning_tool]
220
  agent = initialize_agent(
 
242
  def main():
243
  load_environment()
244
 
 
 
245
  if "chat_ready" not in st.session_state:
246
  st.session_state.chat_ready = False
247
  if "chat_history" not in st.session_state:
248
  st.session_state.chat_history = []
249
+ if "vectorstore" not in st.session_state:
250
+ st.session_state.vectorstore = None
251
 
252
  st.header("Chat with MODTRAN Documents πŸ“„")
253