SysModeler committed on
Commit 599e88a · verified · 1 Parent(s): d00caa1

Update app.py

Files changed (1)
  1. app.py +225 -51
app.py CHANGED
@@ -1,76 +1,250 @@
  import os
  import warnings
- import gradio as gr
  from dotenv import load_dotenv
-
- from langchain.chains import ConversationalRetrievalChain
  from langchain_community.vectorstores import FAISS
  from langchain_community.embeddings import AzureOpenAIEmbeddings
- from langchain_community.chat_models import AzureChatOpenAI
-
- # Patch Gradio bug
- import gradio_client.utils
- gradio_client.utils.json_schema_to_python_type = lambda schema, defs=None: "string"
-
  # Load environment variables
  load_dotenv()
  AZURE_OPENAI_API_KEY = os.getenv("AZURE_OPENAI_API_KEY")
  AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT")
- #AZURE_END_POINT_O3 = os.getenv("AZURE_END_POINT_O3")
  AZURE_OPENAI_LLM_DEPLOYMENT = os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT")
  AZURE_OPENAI_EMBEDDING_DEPLOYMENT = os.getenv("AZURE_OPENAI_EMBEDDING_DEPLOYMENT")
-
-
- if not all([AZURE_OPENAI_API_KEY,
-             AZURE_OPENAI_ENDPOINT,
-             #AZURE_END_POINT_O3,
-             AZURE_OPENAI_LLM_DEPLOYMENT,
-             AZURE_OPENAI_EMBEDDING_DEPLOYMENT]):
      raise ValueError("Missing one or more Azure OpenAI environment variables.")
-
- # Suppress warnings
  warnings.filterwarnings("ignore")
-
- # Initialize Azure embedding model
  embeddings = AzureOpenAIEmbeddings(
      azure_deployment=AZURE_OPENAI_EMBEDDING_DEPLOYMENT,
      azure_endpoint=AZURE_OPENAI_ENDPOINT,
-     #azure_endpoint=AZURE_END_POINT_O3,
      openai_api_key=AZURE_OPENAI_API_KEY,
-     openai_api_version="2025-01-01-preview",  # updated to latest recommended version
      chunk_size=1000
  )
-
- # Load FAISS vector store
- vectorstore = FAISS.load_local(
-     "faiss_index_sysml", embeddings, allow_dangerous_deserialization=True
- )
-
- # Initialize Azure chat model
- llm = AzureChatOpenAI(
-     deployment_name=AZURE_OPENAI_LLM_DEPLOYMENT,
-     azure_endpoint=AZURE_OPENAI_ENDPOINT,
-     #azure_endpoint=AZURE_END_POINT_O3,
-     openai_api_key=AZURE_OPENAI_API_KEY,
-     openai_api_version="2025-01-01-preview",  # updated to latest recommended version
-     #temperature=0.5
- )
-
- # Build conversational RAG chain
- qa = ConversationalRetrievalChain.from_llm(
-     llm=llm,
-     retriever=vectorstore.as_retriever(),
-     return_source_documents=False
  )

- history = []
-
- # Chatbot logic
  def sysml_chatbot(message, history):
-     result = qa({"question": message, "chat_history": history})
-     answer = result["answer"]
-     history.append((message, answer))
-     return "", history

  # Gradio UI
  with gr.Blocks() as demo:

  import os
  import warnings
+ import json
  from dotenv import load_dotenv
+ from typing import Dict, Any, List, Optional
+ import time
+ from functools import lru_cache
+ import logging
+ import gradio as gr  # still needed below for gr.Blocks()
+
+ from langchain.agents import Tool, AgentExecutor
+ from langchain.tools.retriever import create_retriever_tool
+ from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
  from langchain_community.vectorstores import FAISS
  from langchain_community.embeddings import AzureOpenAIEmbeddings
+ from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
+ from openai import AzureOpenAI
+
  # Load environment variables
  load_dotenv()
  AZURE_OPENAI_API_KEY = os.getenv("AZURE_OPENAI_API_KEY")
  AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT")
  AZURE_OPENAI_LLM_DEPLOYMENT = os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT")
  AZURE_OPENAI_EMBEDDING_DEPLOYMENT = os.getenv("AZURE_OPENAI_EMBEDDING_DEPLOYMENT")
+
+ if not all([AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, AZURE_OPENAI_LLM_DEPLOYMENT, AZURE_OPENAI_EMBEDDING_DEPLOYMENT]):
      raise ValueError("Missing one or more Azure OpenAI environment variables.")
+
  warnings.filterwarnings("ignore")
+
+ # Embeddings for retriever
  embeddings = AzureOpenAIEmbeddings(
      azure_deployment=AZURE_OPENAI_EMBEDDING_DEPLOYMENT,
      azure_endpoint=AZURE_OPENAI_ENDPOINT,
      openai_api_key=AZURE_OPENAI_API_KEY,
+     openai_api_version="2025-01-01-preview",
      chunk_size=1000
  )
+
+ # Get the directory where this script is located
+ SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
+
+ # Build the absolute path to the faiss_index_sysml directory relative to this script
+ FAISS_INDEX_PATH = os.path.join(SCRIPT_DIR, "faiss_index_sysml")
+ # Load FAISS vectorstore
+ vectorstore = FAISS.load_local(FAISS_INDEX_PATH, embeddings, allow_dangerous_deserialization=True)
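+ # allow_dangerous_deserialization is required because the saved index is pickled;
+ # this is safe only when the index files are trusted (here, bundled with the app)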
+
+ # Initialize Azure OpenAI client directly
+ client = AzureOpenAI(
+     api_key=AZURE_OPENAI_API_KEY,
+     api_version="2025-01-01-preview",
+     azure_endpoint=AZURE_OPENAI_ENDPOINT
  )
+
+ logger = logging.getLogger(__name__)
+
+ # SysML retriever function
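+ # lru_cache memoizes results per exact query string, so repeated questions skip the vector search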
+ @lru_cache(maxsize=100)
+ def sysml_retriever(query: str) -> str:
+     start_time = time.time()
+     try:
+         results = vectorstore.similarity_search(query, k=100)
+         contexts = [doc.page_content for doc in results]
+         response = "\n\n".join(contexts)
+
+         # Log performance metrics
+         duration = time.time() - start_time
+         print(f"Retrieval completed in {duration:.2f}s for query: {query[:50]}...")
+
+         return response
+     except Exception as e:
+         logger.error(f"Retrieval error: {str(e)}")
+         return "Unable to retrieve information at this time."
+
+ # sysml_retriever = create_retriever_tool(
+ #     retriever=vectorstore.as_retriever(),
+ #     name="SysMLRetriever",
+ #     description="Use this to answer questions about SysML diagrams and modeling."
+ # )

+ # Dummy functions
+ def dummy_weather_lookup(location: str = "London") -> str:
+     return f"The weather in {location} is sunny and 25°C."
+
+ def dummy_time_lookup(timezone: str = "UTC") -> str:
+     return f"The current time in {timezone} is 3:00 PM."
+
+ # Tools definition for OpenAI function calling
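+ # Each entry follows the OpenAI tools schema: a function name, a description the model
+ # uses when deciding which tool to call, and a JSON Schema block describing the parameters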
+ tools_definition = [
+     {
+         "type": "function",
+         "function": {
+             "name": "SysMLRetriever",
+             "description": "Use this to answer questions about SysML diagrams and modeling.",
+             "parameters": {
+                 "type": "object",
+                 "properties": {
+                     "query": {
+                         "type": "string",
+                         "description": "The search query to find information about SysML"
+                     }
+                 },
+                 "required": ["query"]
+             }
+         }
+     },
+     {
+         "type": "function",
+         "function": {
+             "name": "WeatherLookup",
+             "description": "Use this to look up the current weather in a specified location.",
+             "parameters": {
+                 "type": "object",
+                 "properties": {
+                     "location": {
+                         "type": "string",
+                         "description": "The location to look up the weather for"
+                     }
+                 },
+                 "required": ["location"]
+             }
+         }
+     },
+     {
+         "type": "function",
+         "function": {
+             "name": "TimeLookup",
+             "description": "Use this to look up the current time in a specified timezone.",
+             "parameters": {
+                 "type": "object",
+                 "properties": {
+                     "timezone": {
+                         "type": "string",
+                         "description": "The timezone to look up the current time for"
+                     }
+                 },
+                 "required": ["timezone"]
+             }
+         }
+     }
+ ]
+
+ # Tool execution mapping
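+ # Dispatch table: tool names returned by the model map to the local callables defined above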
+ tool_mapping = {
+     "SysMLRetriever": sysml_retriever,
+     "WeatherLookup": dummy_weather_lookup,
+     "TimeLookup": dummy_time_lookup
+ }
+
+ # Convert chat history
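+ # Gradio's tuple-style history [(user, bot), ...] is flattened into OpenAI chat messages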
+ def convert_history_to_messages(history):
+     messages = []
+     for user, bot in history:
+         messages.append({"role": "user", "content": user})
+         messages.append({"role": "assistant", "content": bot})
+     return messages
+
+ # Main chatbot function with direct function calling
  def sysml_chatbot(message, history):
+     # Convert history to messages format
+     chat_messages = convert_history_to_messages(history)
+
+     # Add system message at the beginning
+     full_messages = [
+         {"role": "system", "content": "You are a helpful SysML modeling assistant and a capable general-purpose assistant."}
+     ]
+     full_messages.extend(chat_messages)
+
+     # Add current user message
+     full_messages.append({"role": "user", "content": message})
+
+     try:
+         # First call to get either a direct answer or a function call
+         response = client.chat.completions.create(
+             model=AZURE_OPENAI_LLM_DEPLOYMENT,
+             messages=full_messages,
+             tools=tools_definition,
+             tool_choice="auto"  # let the model decide whether to answer directly or call a tool
+         )
+
+         assistant_message = response.choices[0].message
+
+         # Check if the model wants to call a function
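+         # Tool-call protocol: echo the assistant's tool_call message, append a "tool"
+         # message carrying the function result, then ask the model again for the final answer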
+         if assistant_message.tool_calls:
+             # Get the function call details
+             tool_call = assistant_message.tool_calls[0]
+             function_name = tool_call.function.name
+             function_args = json.loads(tool_call.function.arguments)
+             print("Attempting function calling...")
+             # Execute the function
+             if function_name in tool_mapping:
+                 function_response = tool_mapping[function_name](**function_args)
+
+                 # Append the assistant's request and the function response to messages
+                 full_messages.append({"role": "assistant", "content": None, "tool_calls": [
+                     {"id": tool_call.id, "type": "function", "function": {"name": function_name, "arguments": tool_call.function.arguments}}
+                 ]})
+
+                 full_messages.append({
+                     "role": "tool",
+                     "tool_call_id": tool_call.id,
+                     "content": function_response
+                 })
+
+                 # Second call to get the final answer based on the function result
+                 second_response = client.chat.completions.create(
+                     model=AZURE_OPENAI_LLM_DEPLOYMENT,
+                     messages=full_messages
+                 )
+
+                 answer = second_response.choices[0].message.content
+                 print("Getting final response after function execution...")
+                 print(f"Function '{function_name}' executed successfully. Response: {answer}")
+             else:
+                 answer = f"I tried to use a function '{function_name}' that's not available. Let me try again with general knowledge: SysML is a modeling language for systems engineering that helps visualize and analyze complex systems."
+         else:
+             # Model provided a direct answer
+             answer = assistant_message.content
+
+         history.append((message, answer))
+         return answer, history
+
+     except Exception as e:
+         print(f"Error in function calling: {str(e)}")
+
+         # Fallback to a direct response without function calling
+         try:
+             simple_messages = [
+                 {"role": "system", "content": "You are a helpful SysML modeling assistant."}
+             ]
+             simple_messages.extend(chat_messages)
+             simple_messages.append({"role": "user", "content": message})
+
+             fallback_response = client.chat.completions.create(
+                 model=AZURE_OPENAI_LLM_DEPLOYMENT,
+                 messages=simple_messages
+             )
+
+             answer = fallback_response.choices[0].message.content
+         except Exception as fallback_error:
+             print(f"Error in fallback: {str(fallback_error)}")
+             answer = "I'm having trouble accessing my tools right now. SysML is a modeling language used in systems engineering to visualize and analyze complex systems through various diagram types."
+
+         history.append((message, answer))
+         return answer, history

  # Gradio UI
  with gr.Blocks() as demo: