# SysModeler chatbot — Gradio app deployed as a Hugging Face Space.
import os | |
import gradio as gr | |
import warnings | |
import json | |
from dotenv import load_dotenv | |
import logging | |
from functools import lru_cache | |
from langchain_community.vectorstores import FAISS | |
from langchain_community.embeddings import AzureOpenAIEmbeddings | |
from openai import AzureOpenAI | |
# Patch Gradio bug: json_schema_to_python_type can raise on some tool schemas
# (known gradio_client issue). The return value is only used for API docs
# display, so a constant stub is safe here.
import gradio_client.utils
gradio_client.utils.json_schema_to_python_type = lambda schema, defs=None: "string"

# Load environment variables from a local .env file (if present).
load_dotenv()
AZURE_OPENAI_API_KEY = os.getenv("AZURE_OPENAI_API_KEY")
AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT")
AZURE_OPENAI_LLM_DEPLOYMENT = os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT")
AZURE_OPENAI_EMBEDDING_DEPLOYMENT = os.getenv("AZURE_OPENAI_EMBEDDING_DEPLOYMENT")

# Fail fast at startup if any required Azure OpenAI setting is missing.
if not all([AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, AZURE_OPENAI_LLM_DEPLOYMENT, AZURE_OPENAI_EMBEDDING_DEPLOYMENT]):
    raise ValueError("Missing one or more Azure OpenAI environment variables.")
# Silence library warnings in the deployed app's logs.
warnings.filterwarnings("ignore")

# Embeddings client — must point at the same embedding deployment the FAISS
# index was built with, or similarity search results will be meaningless.
embeddings = AzureOpenAIEmbeddings(
    azure_deployment=AZURE_OPENAI_EMBEDDING_DEPLOYMENT,
    azure_endpoint=AZURE_OPENAI_ENDPOINT,
    openai_api_key=AZURE_OPENAI_API_KEY,
    openai_api_version="2025-01-01-preview",
    chunk_size=1000
)

# Vectorstore: load the prebuilt FAISS index shipped alongside this script.
# allow_dangerous_deserialization is required by FAISS.load_local because the
# index metadata is pickled; acceptable here since the index is a local,
# trusted build artifact, never user-supplied.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
FAISS_INDEX_PATH = os.path.join(SCRIPT_DIR, "faiss_index_sysml")
vectorstore = FAISS.load_local(FAISS_INDEX_PATH, embeddings, allow_dangerous_deserialization=True)

# Raw Azure OpenAI client used directly for chat completions / tool calling.
client = AzureOpenAI(
    api_key=AZURE_OPENAI_API_KEY,
    api_version="2025-01-01-preview",
    azure_endpoint=AZURE_OPENAI_ENDPOINT
)

logger = logging.getLogger(__name__)
def clean_em_dashes(text: str) -> str:
    """Remove em dashes from *text*.

    Common "—word" patterns are rewritten into comma- or sentence-style
    punctuation first; any remaining em dash becomes ", ".
    """
    # Order matters: the specific word patterns must run before the
    # catch-all "—" replacement at the end.
    rewrite_table = (
        ("—which", ", which"),
        ("—that", ", that"),
        ("—no", ". No"),
        ("—and", ", and"),
        ("—but", ", but"),
        ("—so", ", so"),
        ("—you", ". You"),
        ("—it", ". It"),
        ("—just", ". Just"),
        ("—great", ", great"),
        ("—this", ". This"),
        ("—", ", "),
    )
    for pattern, replacement in rewrite_table:
        text = text.replace(pattern, replacement)
    return text
def sysml_retriever(query: str) -> str:
    """Retrieve SysML / SysModeler.ai context for *query* from the FAISS index.

    Returns the selected document chunks' text joined by blank lines.
    Chunks that look like SysModeler.ai product documentation are boosted
    ahead of generic SysML sources. Never raises: on failure it logs the
    error and returns a fallback message.
    """
    try:
        # Over-fetch (k=100) so the re-weighting below has a wide pool.
        # similarity_search_with_score yields (doc, score) where a LOWER
        # score means a closer match (distance-like).
        results = vectorstore.similarity_search_with_score(query, k=100)
        weighted_results = []
        for (doc, score) in results:
            doc_source = doc.metadata.get('source', '').lower() if hasattr(doc, 'metadata') else str(doc).lower()
            # Heuristic product detector: source path or page content mentions
            # SysModeler-specific terms. NOTE(review): some checks are
            # case-sensitive ('Create with AI', 'AI Copilot', ...) while
            # others are lowercased — presumably to match exact UI labels;
            # confirm this mix is intentional.
            is_sysmodeler = (
                'sysmodeler' in doc_source or
                'user manual' in doc_source or
                'sysmodeler.ai' in doc.page_content.lower() or
                'workspace.sysmodeler.ai' in doc.page_content.lower() or
                'Create with AI' in doc.page_content or
                'Canvas Overview' in doc.page_content or
                'AI-powered' in doc.page_content or
                'voice input' in doc.page_content or
                'Canvas interface' in doc.page_content or
                'Project Creation' in doc.page_content or
                'Shape Palette' in doc.page_content or
                'AI Copilot' in doc.page_content or
                'SynthAgent' in doc.page_content or
                'workspace dashboard' in doc.page_content.lower()
            )
            if is_sysmodeler:
                # Shrinking the distance ranks SysModeler chunks higher.
                weighted_score = score * 0.6
                source_type = "SysModeler"
            else:
                weighted_score = score
                source_type = "Other"
            # Stash scoring details on the doc's metadata for inspection.
            doc.metadata = doc.metadata if hasattr(doc, 'metadata') else {}
            doc.metadata['source_type'] = 'sysmodeler' if is_sysmodeler else 'other'
            doc.metadata['weighted_score'] = weighted_score
            doc.metadata['original_score'] = score
            weighted_results.append((doc, weighted_score, source_type))
        # Ascending sort: best (lowest weighted distance) first.
        weighted_results.sort(key=lambda x: x[1])
        query_lower = query.lower()
        # Tool-comparison questions get a guaranteed SysModeler-heavy mix
        # (up to 8 product chunks + up to 4 others); otherwise take top 12.
        is_tool_comparison = any(word in query_lower for word in ['tool', 'compare', 'choose', 'vs', 'versus', 'better'])
        if is_tool_comparison:
            sysmodeler_docs = [(doc, score) for doc, score, type_ in weighted_results if type_ == "SysModeler"][:8]
            other_docs = [(doc, score) for doc, score, type_ in weighted_results if type_ == "Other"][:4]
            final_docs = [doc for doc, _ in sysmodeler_docs] + [doc for doc, _ in other_docs]
        else:
            final_docs = [doc for doc, _, _ in weighted_results[:12]]
        contexts = [doc.page_content for doc in final_docs]
        return "\n\n".join(contexts)
    except Exception as e:
        logger.error(f"Retrieval error: {str(e)}")
        return "Unable to retrieve information at this time."
# OpenAI function-calling schema: the single retrieval tool exposed to the
# model. The "query" argument is the only parameter and is required.
tools_definition = [
    {
        "type": "function",
        "function": {
            "name": "SysMLRetriever",
            "description": "Use this to answer questions about SysML diagrams and modeling.",
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {"type": "string", "description": "The search query to find information about SysML"}
                },
                "required": ["query"]
            }
        }
    }
]

# Dispatch table mapping tool names returned by the model to local functions.
tool_mapping = {
    "SysMLRetriever": sysml_retriever
}
def convert_history_to_messages(history):
    """Flatten Gradio-style (user, bot) history pairs into OpenAI chat
    messages, preserving turn order."""
    return [
        {"role": role, "content": content}
        for user_text, bot_text in history
        for role, content in (("user", user_text), ("assistant", bot_text))
    ]
def sysml_chatbot(message, history):
    """Handle one chat turn.

    Parameters
    ----------
    message : str
        The user's new message (may be empty/whitespace).
    history : list[tuple[str, str]]
        Gradio-style (user, bot) pairs accumulated so far; mutated in place.

    Returns
    -------
    tuple[str, list]
        ("", history) — clears the textbox and returns the updated history.
    """
    # Gracefully handle empty submissions without calling the model.
    if not message or not message.strip():
        answer = "Can I help you with anything else?"
        history.append(("", answer))
        return "", history
    chat_messages = convert_history_to_messages(history)
    full_messages = [
        {"role": "system", "content": """You are Abu, SysModeler.ai's friendly and knowledgeable assistant. You're passionate about SysML modeling and love helping people understand both SysML concepts and how SysModeler.ai can make their modeling work easier.
CONVERSATION STYLE:
- Only introduce yourself as "Hi, I'm Abu!" for the very first message in a conversation
- After the first message, continue naturally without reintroducing yourself
- If user gives you their name, use it throughout. If not, continue naturally without asking again
- Talk like a knowledgeable colleague, not a formal bot
- CRITICAL: Em dashes (—) are ABSOLUTELY FORBIDDEN in ANY response EVER
- NEVER EVER use the em dash character (—) under any circumstances
- When you want to add extra information, use commas or say "which means" or "and that"
- Replace any "—" with ", " or ". " or " and " or " which "
- Be enthusiastic but not pushy about SysModeler.ai
- Ask engaging follow-up questions to keep the conversation going
- Use "you" and "your" to make it personal
- Share insights like you're having a friendly chat
"""}
    ] + chat_messages + [{"role": "user", "content": message}]
    try:
        # First call: tool_choice FORCES the model to invoke SysMLRetriever,
        # so every answer is grounded in retrieved context.
        response = client.chat.completions.create(
            model=AZURE_OPENAI_LLM_DEPLOYMENT,
            messages=full_messages,
            tools=tools_definition,
            tool_choice={"type": "function", "function": {"name": "SysMLRetriever"}}
        )
        assistant_message = response.choices[0].message
        if assistant_message.tool_calls:
            # Only one tool is defined, so the first call is the only one.
            tool_call = assistant_message.tool_calls[0]
            function_name = tool_call.function.name
            function_args = json.loads(tool_call.function.arguments)
            if function_name in tool_mapping:
                function_response = tool_mapping[function_name](**function_args)
                # Echo the assistant's tool call and the tool result back so
                # the second completion can compose the grounded answer.
                full_messages.append({
                    "role": "assistant",
                    "content": None,
                    "tool_calls": [{
                        "id": tool_call.id,
                        "type": "function",
                        "function": {
                            "name": function_name,
                            "arguments": tool_call.function.arguments
                        }
                    }]
                })
                full_messages.append({
                    "role": "tool",
                    "tool_call_id": tool_call.id,
                    "content": function_response
                })
                second_response = client.chat.completions.create(
                    model=AZURE_OPENAI_LLM_DEPLOYMENT,
                    messages=full_messages
                )
                answer = second_response.choices[0].message.content
                # Guard against None content (e.g. filtered responses) before
                # post-processing — previously this branch was unguarded.
                answer = clean_em_dashes(answer) if answer else answer
            else:
                answer = f"I tried to use a function '{function_name}' that's not available."
        else:
            answer = assistant_message.content
            answer = clean_em_dashes(answer) if answer else answer
        history.append((message, answer))
        return "", history
    except Exception:
        # Log the full traceback instead of silently swallowing the error,
        # but still return a friendly message so the UI keeps working.
        logger.exception("Chat turn failed")
        history.append((message, "Sorry, something went wrong."))
        return "", history
# === Gradio UI ===
with gr.Blocks(css="""
#submit-btn {
    height: 100%;
    background-color: #48CAE4;
    color: white;
    font-size: 1.5em;
}
""") as demo:
    gr.Markdown("## SysModeler Chatbot")
    chatbot = gr.Chatbot(height=600)
    with gr.Row():
        with gr.Column(scale=5):
            msg = gr.Textbox(
                placeholder="Ask me about SysML diagrams or concepts...",
                lines=3,
                show_label=False
            )
        with gr.Column(scale=1, min_width=50):
            submit_btn = gr.Button("➤", elem_id="submit-btn")
    clear = gr.Button("Clear")
    # Shared conversation history: sysml_chatbot mutates this list in place,
    # which is why it is not listed among the event outputs below.
    state = gr.State([])
    submit_btn.click(fn=sysml_chatbot, inputs=[msg, state], outputs=[msg, chatbot])
    msg.submit(fn=sysml_chatbot, inputs=[msg, state], outputs=[msg, chatbot])
    # Bug fix: also reset `state`. Previously only the visible chat and the
    # textbox were cleared, so the stale history kept feeding the model
    # context on the next turn after "Clear".
    clear.click(fn=lambda: ([], "", []), inputs=None, outputs=[chatbot, msg, state])

if __name__ == "__main__":
    demo.launch()