Update app.py
app.py CHANGED
@@ -18,25 +18,30 @@ from sentence_transformers import CrossEncoder
 import google.generativeai as genai
 from typing import List
 from langchain_core.language_models import BaseLanguageModel
+from langchain_core.runnables import Runnable
 
 import google.generativeai as genai
 
 
-class GeminiLLM(BaseLanguageModel):
+class GeminiLLM(Runnable):
     def __init__(self, model_name="models/gemini-1.5-pro-latest", api_key=None):
         self.api_key = api_key or st.secrets["GOOGLE_API_KEY"]
         if not self.api_key:
-            raise ValueError("GOOGLE_API_KEY not found")
+            raise ValueError("GOOGLE_API_KEY not found.")
         genai.configure(api_key=self.api_key)
         self.model = genai.GenerativeModel(model_name)
-
-    def _call(self, prompt, stop=None):
+
+    def _call(self, prompt: str, stop=None) -> str:
         response = self.model.generate_content(prompt)
         return response.text
 
     @property
-    def _llm_type(self):
+    def _llm_type(self) -> str:
         return "custom_gemini"
+
+    def invoke(self, input, config=None):
+        response = self.model.generate_content(input)
+        return response.text.strip()
 
 class GeminiEmbeddings(Embeddings):
     def __init__(self, model_name="models/embedding-001", api_key=None):
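Note: subclassing `Runnable` is the core of this change. LangChain's agent machinery drives models through `.invoke()`, which the previous plain class did not provide. A minimal sketch of the pattern, assuming `langchain-core` and `google-generativeai` are installed; the `GeminiSketch` name and the `os.environ` lookup are illustrative (app.py itself reads `st.secrets`):

```python
import os
import google.generativeai as genai
from langchain_core.runnables import Runnable

class GeminiSketch(Runnable):
    """Minimal Runnable wrapper; invoke() is the one method LangChain requires."""
    def __init__(self, model_name="models/gemini-1.5-pro-latest"):
        genai.configure(api_key=os.environ["GOOGLE_API_KEY"])  # illustrative; app.py uses st.secrets
        self.model = genai.GenerativeModel(model_name)

    def invoke(self, input, config=None):
        # LangChain hands the rendered prompt string in as `input`.
        return self.model.generate_content(input).text.strip()

llm = GeminiSketch()
print(llm.invoke("Summarize MODTRAN in one sentence."))
```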
@@ -64,19 +69,6 @@ class GeminiEmbeddings(Embeddings):
             task_type="retrieval_query"
         )["embedding"]
 
-
-class GeminiLLM:
-    def __init__(self, model_name="models/gemini-1.5-pro-latest", api_key=None):
-        api_key = api_key or os.getenv("GOOGLE_API_KEY")
-        if not api_key:
-            raise ValueError("Missing GOOGLE_API_KEY")
-        genai.configure(api_key=api_key)
-        self.model = genai.GenerativeModel(model_name)
-
-    def predict(self, prompt: str) -> str:
-        response = self.model.generate_content(prompt)
-        return response.text.strip()
-
 reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
 
 vectorstore_global = None
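Note: the deleted block was a second, incompatible `GeminiLLM` definition. Python silently rebinds the name on the later `class` statement, so the `Runnable`-based class defined earlier in app.py was being shadowed. A self-contained demo of the failure mode:

```python
class GeminiLLM:
    def _call(self, prompt):
        return "first definition"

class GeminiLLM:  # later definition silently wins
    def predict(self, prompt):
        return "second definition"

llm = GeminiLLM()
print(hasattr(llm, "_call"))  # False: the first class is gone
print(llm.predict("hi"))      # "second definition"
```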
@@ -88,8 +80,8 @@ def load_environment():
 def preload_modtran_document():
     global vectorstore_global
     embeddings = GeminiEmbeddings()
-    vectorstore = FAISS.load_local("modtran_vectorstore", embeddings, allow_dangerous_deserialization=True)
-    set_global_vectorstore(vectorstore)
+    st.session_state.vectorstore = FAISS.load_local("modtran_vectorstore", embeddings, allow_dangerous_deserialization=True)
+    set_global_vectorstore(st.session_state.vectorstore)
     st.session_state.chat_ready = True
 
 def convert_pdf_to_xml(pdf_file, xml_path):
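Note: Streamlit reruns the script top to bottom on every interaction, so a value held only in a local variable or module-level global is rebuilt each time, while `st.session_state` persists for the browser session. A sketch of the pattern this hunk moves to, with a hypothetical `load_store()` standing in for the FAISS call:

```python
import streamlit as st

def load_store():
    # Hypothetical stand-in for FAISS.load_local("modtran_vectorstore", ...).
    return object()

# Deserialize once per session; later reruns reuse the cached object
# instead of reloading the index from disk.
if st.session_state.get("vectorstore") is None:
    st.session_state.vectorstore = load_store()
```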
@@ -183,7 +175,7 @@ def self_reasoning(query, context):
 
     **Answer:**
     """
-    return llm.predict(reasoning_prompt)
+    return llm._call(reasoning_prompt)
 
 def faiss_search_with_keywords(query):
     global vectorstore_global
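Note: the previous call targeted `predict()`, a method that only existed on the now-removed duplicate class, so after the dedupe `_call()` is the helper that actually exists. Since the surviving class also implements `invoke()`, either spelling reaches the same `generate_content` call; a sketch, assuming the `GeminiLLM` above and a configured API key:

```python
llm = GeminiLLM()
prompt = "### Question:\n...\n### Context:\n...\n**Answer:**"
direct = llm._call(prompt)   # class helper, returns response.text
via_api = llm.invoke(prompt) # Runnable interface, returns stripped text
```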
@@ -222,7 +214,7 @@ faiss_reasoning_tool = Tool(
 )
 
 def initialize_chatbot_agent():
-    llm = GeminiLLM()
+    llm = GeminiLLM()
     memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
     tools = [faiss_keyword_tool, faiss_reasoning_tool]
     agent = initialize_agent(
@@ -250,12 +242,12 @@ def handle_user_query(query):
 def main():
     load_environment()
 
-    if "agent" not in st.session_state:
-        st.session_state.agent = None
     if "chat_ready" not in st.session_state:
         st.session_state.chat_ready = False
     if "chat_history" not in st.session_state:
         st.session_state.chat_history = []
+    if "vectorstore" not in st.session_state:
+        st.session_state.vectorstore = None
 
     st.header("Chat with MODTRAN Documents π")
 
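Note: `main()` drops the `"agent"` seeding and guards `"vectorstore"` instead, matching the new `preload_modtran_document()`. The repeated guard pattern can be collapsed into one loop; a sketch with defaults mirroring the keys used here:

```python
import streamlit as st

DEFAULTS = {"chat_ready": False, "chat_history": [], "vectorstore": None}
for key, default in DEFAULTS.items():
    # Seed each key exactly once; st.session_state survives Streamlit reruns.
    if key not in st.session_state:
        st.session_state[key] = default
```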