|
import streamlit as st |
|
import pandas as pd |
|
|
|
from apis import ask, run_sql |
|
from utils import format_sql |
|
|
|
|
|
def _show_sql_result(sql_result, error):
    """Render the outcome of a ``run_sql`` call in the current container.

    Args:
        sql_result: First element of the ``run_sql`` return pair; read as a
            mapping with an optional ``"records"`` entry. A falsy value is
            treated as a failed execution.
        error: Second element of the ``run_sql`` return pair; interpolated
            into the error message when ``sql_result`` is falsy.
    """
    if not sql_result:
        st.error(f"Error executing SQL: {error}")
        return

    records = sql_result.get("records", [])
    if not records:
        st.info("Query executed but returned no data.")
        return

    # Build the frame first so a malformed payload fails before the success
    # banner is shown.
    frame = pd.DataFrame(records)
    st.success("SQL query executed successfully!")
    st.dataframe(frame, use_container_width=True)


def _render_sql_section(api_key, project_id, sql, message_id, sample_size):
    """Show a generated SQL statement together with its "run" button.

    Used both when replaying history and when rendering a fresh answer. The
    button label and key are built here once so the two call sites always
    produce the same widget: a click on the freshly rendered button is then
    picked up by the history replay on the following script run.

    Args:
        api_key: Credential forwarded to ``run_sql``.
        project_id: Project identifier forwarded to ``run_sql``.
        sql: SQL text to display and execute.
        message_id: Value used to build the button's widget key.
        sample_size: Value of the sidebar slider, forwarded to ``run_sql``.
    """
    with st.expander("π Generated SQL Query", expanded=False):
        st.code(format_sql(sql), language="sql")

    if st.button("π Run SQL Query", key=f"run_sql_{message_id}"):
        with st.spinner("Executing SQL query..."):
            sql_result, error = run_sql(
                api_key,
                project_id,
                sql,
                st.session_state.qa_thread_id,
                sample_size,
            )
        _show_sql_result(sql_result, error)


def _reply_with_fallback(content):
    """Record a canned assistant reply in history and show it immediately.

    Writing the text right away keeps the assistant bubble from staying empty
    until the next script run.
    """
    st.session_state.qa_messages.append({"role": "assistant", "content": content})
    st.write(content)


def _replay_history(api_key, project_id, sample_size):
    """Re-render every stored chat message, including SQL controls."""
    for message in st.session_state.qa_messages:
        with st.chat_message(message["role"]):
            st.write(message["content"])
            if message["role"] == "user":
                continue

            if "sql" in message:
                _render_sql_section(
                    api_key,
                    project_id,
                    message["sql"],
                    message.get("message_id", "unknown"),
                    sample_size,
                )

            if "sql_results" in message:
                st.subheader("π Query Results")
                if message["sql_results"]:
                    st.dataframe(message["sql_results"], use_container_width=True)
                else:
                    st.info("No results returned from the query.")


def _answer_prompt(api_key, project_id, prompt, sample_size):
    """Send ``prompt`` to the Ask API and render the exchange.

    The user message and the assistant reply (or a fallback text) are both
    appended to ``st.session_state.qa_messages``; the thread id returned by
    the API replaces ``st.session_state.qa_thread_id``.
    """
    st.session_state.qa_messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.write(prompt)

    with st.chat_message("assistant"):
        with st.spinner("Generating answer..."):
            ask_response, error = ask(
                api_key,
                project_id,
                prompt,
                st.session_state.qa_thread_id,
                sample_size=sample_size,
            )

        if not ask_response:
            st.toast(f"Error generating answer: {error}", icon="π€")
            _reply_with_fallback(
                "Sorry, I couldn't process your request. Please check your API credentials and try again."
            )
            return

        answer = ask_response.get("summary", "")
        sql_query = ask_response.get("sql", "")
        st.session_state.qa_thread_id = ask_response.get("threadId", "")

        if not answer:
            st.toast("No answer was generated. Please try rephrasing your question.", icon="π€")
            _reply_with_fallback(
                "I couldn't generate an answer for your question. Please try rephrasing it or make sure it's related to your data."
            )
            return

        st.toast("Answer generated successfully!", icon="π")

        # The history length at this point (user message already appended)
        # serves as the id that ties this reply to its "run SQL" button key.
        message_id = len(st.session_state.qa_messages)
        assistant_message = {
            "role": "assistant",
            "content": answer,
            "message_id": message_id,
        }
        if sql_query:
            assistant_message["sql"] = sql_query

        st.session_state.qa_messages.append(assistant_message)
        st.write(answer)

        if sql_query:
            _render_sql_section(api_key, project_id, sql_query, message_id, sample_size)


def main():
    """Render the "Query and Answer" demo page.

    Requires ``api_key`` and ``project_id`` in ``st.session_state``; shows an
    error and returns early when either is missing or empty. Otherwise the
    page draws a sample-size slider in the sidebar, replays the stored chat,
    handles a newly submitted question, and offers a button that clears the
    chat history and thread id.
    """
    # NOTE(review): several emoji literals on this page look mis-encoded
    # (e.g. the toast icons). They are kept as-is here; confirm them against
    # the intended characters.
    st.title("π¬ Wren AI Cloud API Demo - Query and Answer")

    # A missing key and an empty value get the same message, so one check
    # covers both cases.
    api_key = st.session_state.get("api_key")
    project_id = st.session_state.get("project_id")
    if not api_key or not project_id:
        st.error("Please enter your API Key and Project ID in the sidebar of Home page to get started.")
        return

    st.markdown('Using APIs: [Ask](https://wrenai.readme.io/reference/post_ask-1), [Run SQL](https://wrenai.readme.io/reference/cloud_post_run-sql)')

    with st.sidebar:
        st.header("π§ Configuration")
        sample_size = st.slider(
            "Sample Size",
            min_value=100,
            max_value=10000,
            value=1000,
            step=100,
            help="Number of data points to include in results",
        )

    if "qa_messages" not in st.session_state:
        st.session_state.qa_messages = []
    if "qa_thread_id" not in st.session_state:
        st.session_state.qa_thread_id = ""

    _replay_history(api_key, project_id, sample_size)

    if prompt := st.chat_input("Ask a question about your data..."):
        _answer_prompt(api_key, project_id, prompt, sample_size)

    if st.sidebar.button("π§Ή Clear Chat History"):
        st.session_state.qa_messages = []
        st.session_state.qa_thread_id = ""
        st.rerun()
|
|
|
|
|
# Render the page only when this file is run as the main script, not when it
# is imported.
if __name__ == "__main__":
    main()