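"""SysModeler chatbot: a Gradio UI for retrieval-augmented Q&A over SysML material.

Loads a prebuilt FAISS index ("faiss_index_sysml") and answers questions using
Azure OpenAI chat and embedding deployments via LangChain's
ConversationalRetrievalChain.
"""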
import os
import warnings
import gradio as gr
from dotenv import load_dotenv
from langchain.chains import ConversationalRetrievalChain
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import AzureOpenAIEmbeddings
from langchain_community.chat_models import AzureChatOpenAI
# Work around a known gradio_client bug: json_schema_to_python_type can crash
# on some schemas (e.g. boolean JSON schemas), so replace it with a stub that
# always reports "string".
import gradio_client.utils
gradio_client.utils.json_schema_to_python_type = lambda schema, defs=None: "string"
# Load environment variables
load_dotenv()
AZURE_OPENAI_API_KEY = os.getenv("AZURE_OPENAI_API_KEY")
AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT")
#AZURE_END_POINT_O3 = os.getenv("AZURE_END_POINT_O3")
AZURE_OPENAI_LLM_DEPLOYMENT = os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT")
AZURE_OPENAI_EMBEDDING_DEPLOYMENT = os.getenv("AZURE_OPENAI_EMBEDDING_DEPLOYMENT")

# AZURE_END_POINT_O3 is commented out above, so it must not appear in this check
# (referencing it would raise a NameError).
if not all([AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, AZURE_OPENAI_LLM_DEPLOYMENT, AZURE_OPENAI_EMBEDDING_DEPLOYMENT]):
    raise ValueError("Missing one or more Azure OpenAI environment variables.")
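# Expected .env layout (values are placeholders; deployment names are whatever
# was created in Azure OpenAI Studio):
#
#   AZURE_OPENAI_API_KEY=<your-key>
#   AZURE_OPENAI_ENDPOINT=https://<resource-name>.openai.azure.com/
#   AZURE_OPENAI_LLM_DEPLOYMENT=<chat-deployment-name>
#   AZURE_OPENAI_EMBEDDING_DEPLOYMENT=<embedding-deployment-name>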
# Suppress warnings
warnings.filterwarnings("ignore")

# Initialize Azure embedding model
embeddings = AzureOpenAIEmbeddings(
    azure_deployment=AZURE_OPENAI_EMBEDDING_DEPLOYMENT,
    azure_endpoint=AZURE_OPENAI_ENDPOINT,
    #azure_endpoint=AZURE_END_POINT_O3,
    openai_api_key=AZURE_OPENAI_API_KEY,
    openai_api_version="2025-01-01-preview",  # updated to latest recommended version
    chunk_size=1000  # batch size for embedding requests
)
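# Sanity check (query text is only an example):
#   embeddings.embed_query("block definition diagram")  # -> list[float] vector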
# Load FAISS vector store
vectorstore = FAISS.load_local(
    "faiss_index_sysml", embeddings, allow_dangerous_deserialization=True
)
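# The index is expected to already exist on disk. A minimal sketch of how
# "faiss_index_sysml" could have been built (the source document and splitter
# settings below are assumptions, not taken from this repo):
#
#   from langchain_community.document_loaders import PyPDFLoader
#   from langchain.text_splitter import RecursiveCharacterTextSplitter
#
#   docs = PyPDFLoader("sysml_reference.pdf").load()  # hypothetical source file
#   chunks = RecursiveCharacterTextSplitter(
#       chunk_size=1000, chunk_overlap=100
#   ).split_documents(docs)
#   FAISS.from_documents(chunks, embeddings).save_local("faiss_index_sysml")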
# Initialize Azure chat model
llm = AzureChatOpenAI(
    deployment_name=AZURE_OPENAI_LLM_DEPLOYMENT,
    azure_endpoint=AZURE_OPENAI_ENDPOINT,
    #azure_endpoint=AZURE_END_POINT_O3,
    openai_api_key=AZURE_OPENAI_API_KEY,
    openai_api_version="2025-01-01-preview",  # updated to latest recommended version
    temperature=0.5
)
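# Standalone check (prompt is only an example):
#   llm.invoke("Summarize SysML in one sentence.")  # returns an AIMessage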
# Build conversational RAG chain
qa = ConversationalRetrievalChain.from_llm(
    llm=llm,
    retriever=vectorstore.as_retriever(),
    return_source_documents=False
)
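# Quick sanity check of the chain outside the UI (question text is only an example):
#
#   result = qa.invoke({"question": "What does a block definition diagram show?",
#                       "chat_history": []})
#   print(result["answer"])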
history = []

# Chatbot logic: chat_history is a list of (user, assistant) tuples, the format
# both ConversationalRetrievalChain and gr.Chatbot's default mode accept.
def sysml_chatbot(message, history):
    result = qa.invoke({"question": message, "chat_history": history})
    answer = result["answer"]
    history.append((message, answer))  # in-place append keeps gr.State in sync
    return "", history
# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## SysModeler Chatbot")
    chatbot = gr.Chatbot(height=600)
    with gr.Row():
        with gr.Column(scale=5):
            msg = gr.Textbox(
                placeholder="Ask me about SysML diagrams or concepts...",
                lines=3,
                show_label=False
            )
        with gr.Column(scale=1, min_width=50):
            submit_btn = gr.Button("➤")
    clear = gr.Button("Clear")
    state = gr.State(history)

    submit_btn.click(fn=sysml_chatbot, inputs=[msg, state], outputs=[msg, chatbot])
    msg.submit(fn=sysml_chatbot, inputs=[msg, state], outputs=[msg, chatbot])  # still supports Enter key
    # Reset the chat display, the textbox, and the stored history
    # (clearing only the UI would leave stale context in gr.State).
    clear.click(fn=lambda: ([], "", []), inputs=None, outputs=[chatbot, msg, state])
if __name__ == "__main__":
    demo.launch()
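# demo.launch() uses Gradio defaults (localhost:7860); server_name="0.0.0.0"
# and server_port are standard launch options if the app must be reachable
# from outside a container.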