Spaces:
Running
Running
import streamlit as st | |
import pandas as pd | |
from apis import ask, run_sql | |
from utils import format_sql | |
def _render_sql_section(api_key, project_id, message, sample_size):
    """Render a message's generated SQL, a button that runs it, and stored results.

    Args:
        api_key: Wren AI Cloud API key forwarded to ``run_sql``.
        project_id: Wren AI project id forwarded to ``run_sql``.
        message: Assistant chat-history dict; must contain ``"sql"``. On a
            successful run its records are saved under ``"sql_results"``.
            The dict lives in ``st.session_state.qa_messages``, so the results
            are still displayed after later reruns of the script.
        sample_size: Row limit forwarded to ``run_sql``.
    """
    sql_query = message["sql"]
    with st.expander("📝 Generated SQL Query", expanded=False):
        st.code(format_sql(sql_query), language="sql")

    # message_id is the message's position in history, so the key is unique
    # per message and is identical whether the message is rendered fresh or
    # from history (a click is only reported on the following rerun, which
    # renders it from history).
    button_key = f"run_sql_{message.get('message_id', 'unknown')}"
    if st.button("🚀 Run SQL Query", key=button_key):
        with st.spinner("Executing SQL query..."):
            sql_result, error = run_sql(api_key, project_id, sql_query, st.session_state.qa_thread_id, sample_size)
        if sql_result:
            # Persist the records; previously "sql_results" was read but never
            # written, so results vanished on the next interaction.
            message["sql_results"] = sql_result.get("records", [])
            st.success("SQL query executed successfully!")
        else:
            st.error(f"Error executing SQL: {error}")

    if "sql_results" in message:
        st.subheader("📊 Query Results")
        if message["sql_results"]:
            # Convert to DataFrame for better display
            st.dataframe(pd.DataFrame(message["sql_results"]), use_container_width=True)
        else:
            st.info("No results returned from the query.")


def _answer_question(api_key, project_id, prompt, sample_size, language):
    """Send *prompt* to the Ask API, render the reply and append it to history.

    Call this inside the assistant ``st.chat_message`` container. On a
    response, the returned ``threadId`` is stored in
    ``st.session_state.qa_thread_id`` so follow-up questions share a thread.
    Every outcome (answer, empty answer, API error) both renders a reply now
    and appends the same reply to ``st.session_state.qa_messages``.
    """
    with st.spinner("Generating answer..."):
        ask_response, error = ask(api_key, project_id, prompt, st.session_state.qa_thread_id, sample_size=sample_size, language=language)

    if not ask_response:
        st.toast(f"Error generating answer: {error}", icon="🤔")
        fallback = "Sorry, I couldn't process your request. Please check your API credentials and try again."
        st.write(fallback)
        st.session_state.qa_messages.append({"role": "assistant", "content": fallback})
        return

    answer = ask_response.get("summary", "")
    sql_query = ask_response.get("sql", "")
    st.session_state.qa_thread_id = ask_response.get("threadId", "")

    if not answer:
        st.toast("No answer was generated. Please try rephrasing your question.", icon="🤔")
        fallback = "I couldn't generate an answer for your question. Please try rephrasing it or make sure it's related to your data."
        st.write(fallback)
        st.session_state.qa_messages.append({"role": "assistant", "content": fallback})
        return

    st.toast("Answer generated successfully!", icon="🎉")
    assistant_message = {
        "role": "assistant",
        "content": answer,
        # Index this message will occupy in history; unique until the chat is cleared.
        "message_id": len(st.session_state.qa_messages),
    }
    if sql_query:
        assistant_message["sql"] = sql_query
    st.session_state.qa_messages.append(assistant_message)

    st.write(answer)
    if sql_query:
        _render_sql_section(api_key, project_id, assistant_message, sample_size)


def main():
    """Render the SQL-generation chat page of the Wren AI Cloud API demo.

    Reads ``api_key`` and ``project_id`` from ``st.session_state`` (entered on
    the Home page). It replays the chat history kept in
    ``st.session_state.qa_messages`` and answers new questions via the Ask API.
    It also lets the user execute generated SQL through the Run SQL API.
    """
    st.title("💬 Wren AI Cloud API Demo - SQL Generation")

    # A missing entry and an empty value are reported the same way.
    api_key = st.session_state.get("api_key")
    project_id = st.session_state.get("project_id")
    if not api_key or not project_id:
        st.error("Please enter your API Key and Project ID in the sidebar of Home page to get started.")
        return

    st.markdown('Using APIs: [Ask](https://wrenai.readme.io/reference/post_ask-1), [Run SQL](https://wrenai.readme.io/reference/cloud_post_run-sql)')

    # Sidebar for API configuration
    with st.sidebar:
        st.header("🔧 Configuration")
        language = st.text_input(
            "Language",
            "English",
            help="Language of the response. Default is English."
        )
        sample_size = st.slider(
            "Sample Size",
            min_value=100,
            max_value=10000,
            value=1000,
            step=100,
            help="Number of data points to include in results"
        )

    # Initialize chat history
    if "qa_messages" not in st.session_state:
        st.session_state.qa_messages = []
    if "qa_thread_id" not in st.session_state:
        st.session_state.qa_thread_id = ""

    # Replay chat history; only assistant messages ever carry "sql".
    for message in st.session_state.qa_messages:
        with st.chat_message(message["role"]):
            st.write(message["content"])
            if "sql" in message:
                _render_sql_section(api_key, project_id, message, sample_size)

    # Chat input
    if prompt := st.chat_input("Ask a question about your data..."):
        st.session_state.qa_messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.write(prompt)
        with st.chat_message("assistant"):
            _answer_question(api_key, project_id, prompt, sample_size, language)

    # Clear chat button
    if st.sidebar.button("🧹 Clear Chat History"):
        st.session_state.qa_messages = []
        st.session_state.qa_thread_id = ""
        st.rerun()
# Script entry point: render the page only when this module is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()