# Hugging Face Space page header (captured with this file) — Space status: Sleeping
# File size: 7,500 Bytes
# Revisions: fc3a249 e12b72c
# (The page's line-number gutter, lines 1–218, was layout residue and is omitted.)
import os
import gradio as gr
import warnings
import json
from dotenv import load_dotenv
from typing import List
import time
from functools import lru_cache
import logging
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import AzureOpenAIEmbeddings
from openai import AzureOpenAI
# Patch Gradio bug: swap gradio_client's JSON-schema -> Python-type helper for a
# stub that always answers "string".
# NOTE(review): presumably works around gradio_client failing on some component
# schemas while building API info; confirm against the pinned gradio version and
# drop the patch once it is no longer needed.
import gradio_client.utils


def _schema_type_as_string(schema, defs=None):
    """Ignore *schema* and *defs*; report every schema's type as ``"string"``."""
    return "string"


gradio_client.utils.json_schema_to_python_type = _schema_type_as_string
# Load environment variables (python-dotenv reads a local .env file, if any,
# before the values are looked up below).
load_dotenv()
AZURE_OPENAI_API_KEY = os.getenv("AZURE_OPENAI_API_KEY")
AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT")
AZURE_OPENAI_LLM_DEPLOYMENT = os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT")
AZURE_OPENAI_EMBEDDING_DEPLOYMENT = os.getenv("AZURE_OPENAI_EMBEDDING_DEPLOYMENT")

# Fail fast at import time, naming exactly which settings are unset or empty so
# a misconfigured deployment is diagnosable from the traceback alone.
_missing_env_vars = [
    name
    for name, value in (
        ("AZURE_OPENAI_API_KEY", AZURE_OPENAI_API_KEY),
        ("AZURE_OPENAI_ENDPOINT", AZURE_OPENAI_ENDPOINT),
        ("AZURE_OPENAI_LLM_DEPLOYMENT", AZURE_OPENAI_LLM_DEPLOYMENT),
        ("AZURE_OPENAI_EMBEDDING_DEPLOYMENT", AZURE_OPENAI_EMBEDDING_DEPLOYMENT),
    )
    if not value
]
if _missing_env_vars:
    raise ValueError(
        "Missing one or more Azure OpenAI environment variables: "
        + ", ".join(_missing_env_vars)
    )

# Suppress every Python warning process-wide (this also hides library
# deprecation notices).
warnings.filterwarnings("ignore")
# Azure OpenAI REST API version, shared by the embeddings and chat clients below
# so the two cannot drift apart.
AZURE_OPENAI_API_VERSION = "2025-01-01-preview"

# Embeddings model; handed to FAISS.load_local below together with the index.
embeddings = AzureOpenAIEmbeddings(
    azure_deployment=AZURE_OPENAI_EMBEDDING_DEPLOYMENT,
    azure_endpoint=AZURE_OPENAI_ENDPOINT,
    openai_api_key=AZURE_OPENAI_API_KEY,
    openai_api_version=AZURE_OPENAI_API_VERSION,
    chunk_size=1000,
)

# Vectorstore: pre-built FAISS index in the "faiss_index_sysml" directory next
# to this script.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
FAISS_INDEX_PATH = os.path.join(SCRIPT_DIR, "faiss_index_sysml")
# SECURITY: allow_dangerous_deserialization=True permits LangChain to unpickle
# the saved docstore. This is presumably acceptable only because the index is
# bundled with the app; never point FAISS_INDEX_PATH at untrusted files.
vectorstore = FAISS.load_local(
    FAISS_INDEX_PATH, embeddings, allow_dangerous_deserialization=True
)

# Chat-completions client used for both tool selection and the final answer.
client = AzureOpenAI(
    api_key=AZURE_OPENAI_API_KEY,
    api_version=AZURE_OPENAI_API_VERSION,
    azure_endpoint=AZURE_OPENAI_ENDPOINT,
)

# Module logger
logger = logging.getLogger(__name__)
# SysML retriever function
@lru_cache(maxsize=100)
def _cached_sysml_search(query: str, k: int) -> str:
    """Return the top-*k* FAISS matches for *query*, joined by blank lines.

    Memoised for up to 100 distinct ``(query, k)`` pairs. Exceptions are allowed
    to propagate on purpose: ``lru_cache`` does not store a call that raised, so
    a transient retrieval failure is never cached.
    """
    results = vectorstore.similarity_search(query, k=k)
    return "\n\n".join(doc.page_content for doc in results)


def sysml_retriever(query: str, k: int = 100) -> str:
    """Retrieve SysML reference text relevant to *query* from the FAISS index.

    Args:
        query: Free-text search string (supplied by the model's tool call).
        k: Number of chunks to retrieve. Defaults to 100, the previous
            hard-coded value. NOTE(review): 100 chunks can be a lot of prompt
            context; consider lowering.

    Returns:
        The ``page_content`` of each match, joined by blank lines. If the search
        raises, returns a fixed fallback message instead (and logs the error).
        The fallback is not cached, so the next call retries the search.
    """
    try:
        return _cached_sysml_search(query, k)
    except Exception as e:
        logger.exception("Retrieval error: %s", e)
        return "Unable to retrieve information at this time."


# Preserve the lru_cache introspection helpers that the previously decorated
# function exposed.
sysml_retriever.cache_info = _cached_sysml_search.cache_info
sysml_retriever.cache_clear = _cached_sysml_search.cache_clear
# Dummy functions
def dummy_weather_lookup(location: str = "London") -> str:
    """Return a canned weather report for *location* (placeholder, no real lookup)."""
    forecast = f"The weather in {location} is sunny and 25°C."
    return forecast
def dummy_time_lookup(timezone: str = "UTC") -> str:
    """Return a canned clock reading for *timezone* (placeholder, always 3:00 PM)."""
    reading = f"The current time in {timezone} is 3:00 PM."
    return reading
# Tools for function calling: JSON-schema tool descriptions sent to the model.
def _single_string_param_tool(name, description, param, param_description):
    """Build a function-tool schema taking exactly one required string argument."""
    parameters = {
        "type": "object",
        "properties": {
            param: {"type": "string", "description": param_description},
        },
        "required": [param],
    }
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": parameters,
        },
    }


tools_definition = [
    _single_string_param_tool(
        "SysMLRetriever",
        "Use this to answer questions about SysML diagrams and modeling.",
        "query",
        "The search query to find information about SysML",
    ),
    _single_string_param_tool(
        "WeatherLookup",
        "Use this to look up the current weather in a specified location.",
        "location",
        "The location to look up the weather for",
    ),
    _single_string_param_tool(
        "TimeLookup",
        "Use this to look up the current time in a specified timezone.",
        "timezone",
        "The timezone to look up the current time for",
    ),
]
# Tool execution mapping: tool name (as advertised in tools_definition) to the
# local Python callable that implements it.
tool_mapping = dict(
    SysMLRetriever=sysml_retriever,
    WeatherLookup=dummy_weather_lookup,
    TimeLookup=dummy_time_lookup,
)
# Convert chat history
def convert_history_to_messages(history) -> List[dict]:
    """Flatten tuple-style chat history into Chat Completions messages.

    Args:
        history: Iterable of ``(user, bot)`` pairs, oldest first.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts in conversation order,
        alternating ``"user"`` / ``"assistant"``. A side that is ``None`` is
        skipped: the API rejects an assistant message whose content is null
        and which carries no tool calls, so forwarding it would break the
        whole request.
    """
    messages = []
    for user, bot in history:
        if user is not None:
            messages.append({"role": "user", "content": user})
        if bot is not None:
            messages.append({"role": "assistant", "content": bot})
    return messages
# Chatbot logic
def _answer_with_tool_results(full_messages, tool_calls):
    """Execute the model's tool calls and return its follow-up answer text.

    Appends the assistant tool-call turn, plus one ``"tool"`` message per call,
    to *full_messages* (mutated in place). It then requests a second completion
    without tools. Every ``call.function.name`` must already be a key of
    ``tool_mapping``. Returns the second completion's ``content``, which may be
    ``None``.
    """
    full_messages.append({
        "role": "assistant",
        "content": None,
        "tool_calls": [
            {
                "id": call.id,
                "type": "function",
                "function": {
                    "name": call.function.name,
                    "arguments": call.function.arguments,
                },
            }
            for call in tool_calls
        ],
    })
    for call in tool_calls:
        # Arguments arrive as a JSON string; treat an empty string as "no args".
        function_args = json.loads(call.function.arguments or "{}")
        function_response = tool_mapping[call.function.name](**function_args)
        full_messages.append({
            "role": "tool",
            "tool_call_id": call.id,
            # Send the tool output back as text.
            "content": str(function_response),
        })
    second_response = client.chat.completions.create(
        model=AZURE_OPENAI_LLM_DEPLOYMENT,
        messages=full_messages,
    )
    return second_response.choices[0].message.content


def sysml_chatbot(message, history):
    """Answer *message* via Azure OpenAI, routing through a tool call.

    The first completion is forced (via ``tool_choice``) to call
    ``SysMLRetriever``. Each requested tool runs locally, and its output feeds a
    second completion that produces the final answer.

    Args:
        message: The user's new chat message.
        history: List of ``(user, assistant)`` tuples held in the Gradio
            ``State``. The new turn is appended to it in place.

    Returns:
        ``("", history)``: an empty string that clears the input textbox, and
        the updated history for the ``Chatbot`` component. On any error, the
        turn is recorded with an apology instead of raising.
    """
    full_messages = [
        {"role": "system", "content": "You are a helpful SysML modeling assistant and also a capable smart Assistant"},
        *convert_history_to_messages(history),
        {"role": "user", "content": message},
    ]
    try:
        response = client.chat.completions.create(
            model=AZURE_OPENAI_LLM_DEPLOYMENT,
            messages=full_messages,
            tools=tools_definition,
            # NOTE(review): forcing SysMLRetriever means WeatherLookup and
            # TimeLookup can never be chosen; use "auto" if they should be reachable.
            tool_choice={"type": "function", "function": {"name": "SysMLRetriever"}},
        )
        assistant_message = response.choices[0].message
        tool_calls = assistant_message.tool_calls or []
        unknown = [c.function.name for c in tool_calls if c.function.name not in tool_mapping]
        if not tool_calls:
            answer = assistant_message.content
        elif unknown:
            answer = f"I tried to use a function '{unknown[0]}' that's not available."
        else:
            answer = _answer_with_tool_results(full_messages, tool_calls)
    except Exception as e:
        logger.exception("Error in function calling: %s", e)
        answer = "Sorry, something went wrong."
    # A completion can come back with null/empty content (e.g. when filtered).
    # Storing None would make the next request send an invalid assistant
    # message, so record a readable placeholder instead.
    if not answer:
        answer = "Sorry, I couldn't generate a response."
    history.append((message, answer))
    return "", history
# === Gradio UI ===
with gr.Blocks(css="""
#submit-btn {
height: 100%;
background-color: #48CAE4;
color: white;
font-size: 1.5em;
}
""") as demo:
    gr.Markdown("## SysModeler Chatbot")
    chatbot = gr.Chatbot(height=600)
    with gr.Row():
        with gr.Column(scale=5):
            msg = gr.Textbox(
                placeholder="Ask me about SysML diagrams or concepts...",
                lines=3,
                show_label=False,
            )
        with gr.Column(scale=1, min_width=50):
            submit_btn = gr.Button("➤", elem_id="submit-btn")
    clear = gr.Button("Clear")
    # Per-session conversation history of (user, assistant) tuples.
    # sysml_chatbot appends to this list in place and returns it to `chatbot`.
    state = gr.State([])
    submit_btn.click(fn=sysml_chatbot, inputs=[msg, state], outputs=[msg, chatbot])
    msg.submit(fn=sysml_chatbot, inputs=[msg, state], outputs=[msg, chatbot])
    # Reset the stored history as well as the visible chat and textbox. Otherwise
    # the next message would bring the "cleared" conversation back.
    clear.click(fn=lambda: ([], "", []), inputs=None, outputs=[chatbot, msg, state])
if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not when imported.
    demo.launch()