Spaces:
Sleeping
Sleeping
import os | |
import gradio as gr | |
import warnings | |
import json | |
from dotenv import load_dotenv | |
from typing import List | |
import time | |
from functools import lru_cache | |
import logging | |
from langchain_community.vectorstores import FAISS | |
from langchain_community.embeddings import AzureOpenAIEmbeddings | |
from openai import AzureOpenAI | |
# Patch Gradio bug: replace gradio_client's JSON-schema -> Python-type converter
# with one that ignores its input and always answers "string".
import gradio_client.utils


def _schema_type_as_string(schema, defs=None):
    """Stand-in for ``gradio_client.utils.json_schema_to_python_type``; always returns "string"."""
    return "string"


gradio_client.utils.json_schema_to_python_type = _schema_type_as_string
# Load environment variables (from a .env file, if present, then the process env).
load_dotenv()
AZURE_OPENAI_API_KEY = os.getenv("AZURE_OPENAI_API_KEY")
AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT")
AZURE_OPENAI_LLM_DEPLOYMENT = os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT")
AZURE_OPENAI_EMBEDDING_DEPLOYMENT = os.getenv("AZURE_OPENAI_EMBEDDING_DEPLOYMENT")
# API version shared by the embeddings and chat clients; overridable via the
# environment, defaulting to the version this app was written against.
AZURE_OPENAI_API_VERSION = os.getenv("AZURE_OPENAI_API_VERSION") or "2025-01-01-preview"

# Fail fast at import time, naming exactly which settings are absent or empty.
_missing_env_vars = [
    name
    for name, value in (
        ("AZURE_OPENAI_API_KEY", AZURE_OPENAI_API_KEY),
        ("AZURE_OPENAI_ENDPOINT", AZURE_OPENAI_ENDPOINT),
        ("AZURE_OPENAI_LLM_DEPLOYMENT", AZURE_OPENAI_LLM_DEPLOYMENT),
        ("AZURE_OPENAI_EMBEDDING_DEPLOYMENT", AZURE_OPENAI_EMBEDDING_DEPLOYMENT),
    )
    if not value
]
if _missing_env_vars:
    raise ValueError(
        "Missing one or more Azure OpenAI environment variables: "
        + ", ".join(_missing_env_vars)
    )

# NOTE(review): this silences every warning process-wide (including deprecation
# warnings from the libraries below); narrow it if specific warnings are the target.
warnings.filterwarnings("ignore")

# Embeddings used both to load the FAISS index and to embed incoming queries.
embeddings = AzureOpenAIEmbeddings(
    azure_deployment=AZURE_OPENAI_EMBEDDING_DEPLOYMENT,
    azure_endpoint=AZURE_OPENAI_ENDPOINT,
    openai_api_key=AZURE_OPENAI_API_KEY,
    openai_api_version=AZURE_OPENAI_API_VERSION,
    chunk_size=1000,
)

# Vectorstore: a prebuilt FAISS index stored next to this script.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
FAISS_INDEX_PATH = os.path.join(SCRIPT_DIR, "faiss_index_sysml")
# allow_dangerous_deserialization=True lets load_local unpickle the saved docstore.
# Only point FAISS_INDEX_PATH at an index you built yourself; never an untrusted upload.
vectorstore = FAISS.load_local(FAISS_INDEX_PATH, embeddings, allow_dangerous_deserialization=True)

# OpenAI client used for the chat completions.
client = AzureOpenAI(
    api_key=AZURE_OPENAI_API_KEY,
    api_version=AZURE_OPENAI_API_VERSION,
    azure_endpoint=AZURE_OPENAI_ENDPOINT,
)

# Logger
logger = logging.getLogger(__name__)
# SysML retriever function
def sysml_retriever(query: str, k: int = 100) -> str:
    """Return the text of the ``k`` indexed chunks most similar to *query*.

    Registered as the ``SysMLRetriever`` tool; the returned string is sent back
    to the model as the tool's output.

    Args:
        query: Search text, normally supplied by the model's tool call.
        k: Number of chunks to fetch from the vectorstore. Defaults to 100, the
            previously hard-coded value.
            NOTE(review): 100 chunks can be a lot of context for the follow-up
            completion; consider tuning.

    Returns:
        The chunks' ``page_content`` joined by blank lines (an empty string if
        nothing matched), or a fixed apology string if the search raises.
    """
    try:
        results = vectorstore.similarity_search(query, k=k)
    except Exception as e:
        # Best-effort by design: keep the chat alive, but log the full traceback.
        logger.exception("Retrieval error: %s", e)
        return "Unable to retrieve information at this time."
    return "\n\n".join(doc.page_content for doc in results)
# Dummy functions (placeholders; they return canned strings, no real lookup)
def dummy_weather_lookup(location: str = "London") -> str:
    """Return a fixed "sunny and 25°C" report mentioning *location*."""
    report = f"The weather in {location} is sunny and 25°C."
    return report
def dummy_time_lookup(timezone: str = "UTC") -> str:
    """Return a fixed "3:00 PM" time string mentioning *timezone*; no clock is read."""
    reply = f"The current time in {timezone} is 3:00 PM."
    return reply
# Tools for function calling
def _single_string_param_tool(name, description, param, param_description):
    """Build a function-tool spec whose only (required) argument is the string *param*."""
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": {
                "type": "object",
                "properties": {
                    param: {"type": "string", "description": param_description},
                },
                "required": [param],
            },
        },
    }


# Tool specs passed as `tools=` to the chat completion; names must match tool_mapping.
tools_definition = [
    _single_string_param_tool(
        "SysMLRetriever",
        "Use this to answer questions about SysML diagrams and modeling.",
        "query",
        "The search query to find information about SysML",
    ),
    _single_string_param_tool(
        "WeatherLookup",
        "Use this to look up the current weather in a specified location.",
        "location",
        "The location to look up the weather for",
    ),
    _single_string_param_tool(
        "TimeLookup",
        "Use this to look up the current time in a specified timezone.",
        "timezone",
        "The timezone to look up the current time for",
    ),
]
# Tool execution mapping: tool name (as declared in tools_definition) -> Python callable.
tool_mapping = dict(
    SysMLRetriever=sysml_retriever,
    WeatherLookup=dummy_weather_lookup,
    TimeLookup=dummy_time_lookup,
)
# Convert chat history
def convert_history_to_messages(history):
    """Turn ``(user, assistant)`` pairs into chat-completion message dicts.

    Args:
        history: Iterable of two-item ``(user_text, assistant_text)`` pairs in
            conversation order (tuples or lists).

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts, with each pair's user
        message before its assistant message. A side whose text is ``None`` is
        skipped. The Chat Completions API requires content on user messages and
        on assistant messages without tool calls, so replaying a ``None`` reply
        (e.g. a model answer stored as ``None``) would make every later request fail.
    """
    messages = []
    for user, bot in history:
        if user is not None:
            messages.append({"role": "user", "content": user})
        if bot is not None:
            messages.append({"role": "assistant", "content": bot})
    return messages
# Chatbot logic
def sysml_chatbot(message, history):
    """Answer one user turn via a tool call followed by a second completion.

    Flow:
    1. Send the conversation with ``tools_definition``, with ``tool_choice``
       forced to ``SysMLRetriever``.
    2. Run the requested tool from ``tool_mapping``.
    3. Send the tool output back in a second request; its text is the answer.

    Args:
        message: The new user message.
        history: List of ``(user, assistant)`` pairs. Mutated in place: the new
            ``(message, answer)`` pair is appended, on success and on error.

    Returns:
        ``("", history)``: an empty string to clear the input box, and the same
        (now extended) history list for the chatbot display. A blank message
        returns immediately with *history* untouched.
    """
    # Skip blank submissions rather than spending two API calls on them.
    if not message or not message.strip():
        return "", history

    chat_messages = convert_history_to_messages(history)
    full_messages = [
        {"role": "system", "content": "You are a helpful SysML modeling assistant and also a capable smart Assistant"}
    ] + chat_messages + [{"role": "user", "content": message}]
    try:
        response = client.chat.completions.create(
            model=AZURE_OPENAI_LLM_DEPLOYMENT,
            messages=full_messages,
            tools=tools_definition,
            # NOTE(review): this forces a SysMLRetriever call on every turn, so the
            # WeatherLookup/TimeLookup tools are never selected here; switch to
            # "auto" if those tools are meant to be reachable.
            tool_choice={"type": "function", "function": {"name": "SysMLRetriever"}},
        )
        assistant_message = response.choices[0].message
        if assistant_message.tool_calls:
            # Only the first tool call is executed and replayed to the model.
            tool_call = assistant_message.tool_calls[0]
            function_name = tool_call.function.name
            function_args = json.loads(tool_call.function.arguments)
            if function_name in tool_mapping:
                function_response = tool_mapping[function_name](**function_args)
                # Replay the assistant's tool call, then the tool's result linked by
                # tool_call_id, so the second request can use the result.
                full_messages.append({
                    "role": "assistant",
                    "content": None,
                    "tool_calls": [{
                        "id": tool_call.id,
                        "type": "function",
                        "function": {
                            "name": function_name,
                            "arguments": tool_call.function.arguments,
                        },
                    }],
                })
                full_messages.append({
                    "role": "tool",
                    "tool_call_id": tool_call.id,
                    "content": function_response,
                })
                second_response = client.chat.completions.create(
                    model=AZURE_OPENAI_LLM_DEPLOYMENT,
                    messages=full_messages,
                )
                answer = second_response.choices[0].message.content
            else:
                answer = f"I tried to use a function '{function_name}' that's not available."
        else:
            answer = assistant_message.content
        # The message content can be None. Storing None in history would be replayed
        # as null content on the next turn and make that request fail.
        if answer is None:
            answer = "Sorry, I couldn't generate a response."
        history.append((message, answer))
        return "", history
    except Exception:
        logger.exception("Error in function calling")
        history.append((message, "Sorry, something went wrong."))
        return "", history
# === Gradio UI ===
with gr.Blocks(css="""
#submit-btn {
    height: 100%;
    background-color: #48CAE4;
    color: white;
    font-size: 1.5em;
}
""") as demo:
    gr.Markdown("## SysModeler Chatbot")
    chatbot = gr.Chatbot(height=600)
    with gr.Row():
        with gr.Column(scale=5):
            msg = gr.Textbox(
                placeholder="Ask me about SysML diagrams or concepts...",
                lines=3,
                show_label=False,
            )
        with gr.Column(scale=1, min_width=50):
            submit_btn = gr.Button("➤", elem_id="submit-btn")
    clear = gr.Button("Clear")
    # Per-session list of (user, assistant) pairs. sysml_chatbot appends to this
    # list in place, and it is the history replayed to the model on each turn.
    state = gr.State([])
    submit_btn.click(fn=sysml_chatbot, inputs=[msg, state], outputs=[msg, chatbot])
    msg.submit(fn=sysml_chatbot, inputs=[msg, state], outputs=[msg, chatbot])
    # Reset the stored history as well as the display. Otherwise the next message
    # would bring the "cleared" conversation back and resend it to the model.
    clear.click(fn=lambda: ([], [], ""), inputs=None, outputs=[chatbot, state, msg])

if __name__ == "__main__":
    demo.launch()