|
import streamlit as st |
|
|
|
from apis import generate_sql, generate_chart |
|
from utils import format_sql |
|
|
|
|
|
# Icons/labels used by this page. The original emoji literals were
# mis-encoded (e.g. "π¨"), which Streamlit does not accept as a toast icon.
# NOTE(review): the error (🚨), wrench (🔧) and trash-can (🗑️) glyphs were
# decoded unambiguously from the garbled bytes; the title, expander-label and
# success glyphs below are best guesses — confirm against the intended design.
_ERROR_ICON = "🚨"
_SUCCESS_ICON = "🎉"
_SQL_EXPANDER_LABEL = "📝 Generated SQL Query"
_CHART_EXPANDER_LABEL = "📋 Chart Specification"
_MISSING_CREDENTIALS_MSG = "Please enter your API Key and Project ID in the sidebar of Home page to get started."


def _render_sql(sql):
    """Show *sql*, pretty-printed with ``format_sql``, in a collapsed expander."""
    with st.expander(_SQL_EXPANDER_LABEL, expanded=False):
        st.code(format_sql(sql), language="sql")


def _render_chart(vega_spec):
    """Show the raw Vega-Lite spec (collapsed) and render it as a chart.

    Rendering failures are reported with a toast instead of propagating, so a
    single bad spec does not break the rest of the page.
    """
    try:
        with st.expander(_CHART_EXPANDER_LABEL, expanded=False):
            st.json(vega_spec)
        st.vega_lite_chart(vega_spec)
    except Exception as e:  # UI boundary: surface the error, keep the page alive
        st.toast(f"Error rendering chart: {e}", icon=_ERROR_ICON)


def _add_assistant_message(message):
    """Append *message* to the chat history and display its text immediately."""
    st.session_state.messages.append(message)
    st.write(message["content"])


def _render_history():
    """Replay every stored chat message, including any SQL and chart payloads."""
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.write(message["content"])
            # Only assistant messages carry these keys; user messages are
            # stored with just "role" and "content".
            if "sql" in message:
                _render_sql(message["sql"])
            if "vega_spec" in message:
                _render_chart(message["vega_spec"])


def _answer(api_key, project_id, prompt, sample_size):
    """Generate SQL and then a chart for *prompt*, rendering into the current container.

    Updates ``st.session_state.thread_id`` from the SQL response so follow-up
    questions are sent with the same thread id.
    """
    with st.spinner("Generating SQL query..."):
        sql_response, error = generate_sql(api_key, project_id, prompt, st.session_state.thread_id)

    if not sql_response:
        st.toast(f"Error generating SQL: {error}", icon=_ERROR_ICON)
        _add_assistant_message({
            "role": "assistant",
            "content": "Sorry, I couldn't process your request. Please check your API credentials and try again.",
        })
        return

    sql_query = sql_response.get("sql", "")
    st.session_state.thread_id = sql_response.get("threadId", "")

    if not sql_query:
        st.toast("No SQL query was generated. Please try rephrasing your question.", icon=_ERROR_ICON)
        _add_assistant_message({
            "role": "assistant",
            "content": "I couldn't generate a SQL query for your question. Please try rephrasing it or make sure it's related to your data.",
        })
        return

    st.toast("SQL query generated successfully!", icon=_SUCCESS_ICON)
    _add_assistant_message({
        "role": "assistant",
        "content": f"I've generated a SQL query for your question: '{prompt}'",
        "sql": sql_query,
    })
    _render_sql(sql_query)

    with st.spinner("Generating chart..."):
        chart_response, error = generate_chart(
            api_key,
            project_id,
            prompt,
            sql_query,
            thread_id=st.session_state.thread_id,
            sample_size=sample_size,
        )

    if not chart_response:
        st.toast(f"Failed to generate chart. Please check your query and try again. Error: {error}", icon=_ERROR_ICON)
        return

    vega_spec = chart_response.get("vegaSpec", {})
    if not vega_spec:
        st.toast("Failed to generate chart. Please check your query and try again.", icon=_ERROR_ICON)
        return

    st.toast("Chart generated successfully!", icon=_SUCCESS_ICON)
    _add_assistant_message({
        "role": "assistant",
        "content": f"I've generated a Chart for your question: '{prompt}'",
        "vega_spec": vega_spec,
    })
    _render_chart(vega_spec)


def main():
    """Render the chart-generation chat page.

    Requires ``api_key`` and ``project_id`` in ``st.session_state`` (set from
    the Home page); otherwise shows an error and returns. Chat history and the
    API thread id are kept in ``st.session_state.messages`` and
    ``st.session_state.thread_id``. Each question is answered by calling
    ``generate_sql`` and then ``generate_chart`` with the generated SQL.
    """
    st.title("📊 Wren AI Cloud API Demo - Chart Generation")

    # One check covers both "never set" and "set but empty".
    api_key = st.session_state.get("api_key")
    project_id = st.session_state.get("project_id")
    if not api_key or not project_id:
        st.error(_MISSING_CREDENTIALS_MSG)
        return

    st.markdown('Using APIs: [SQL Generation](https://wrenai.readme.io/reference/cloud_post_generate-sql), [Chart Generation](https://wrenai.readme.io/reference/cloud_post_generate-vega-chart)')

    with st.sidebar:
        st.header("🔧 Configuration")

        sample_size = st.slider(
            "Chart Sample Size",
            min_value=100,
            max_value=10000,
            value=1000,
            step=100,
            help="Number of data points to include in charts",
        )

        # Handled before the history is drawn so a cleared chat never renders
        # stale messages first.
        if st.button("🗑️ Clear Chat History"):
            st.session_state.messages = []
            st.session_state.thread_id = ""
            st.rerun()

    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "thread_id" not in st.session_state:
        st.session_state.thread_id = ""

    _render_history()

    if prompt := st.chat_input("Ask a question about your data..."):
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.write(prompt)

        with st.chat_message("assistant"):
            _answer(api_key, project_id, prompt, sample_size)
|
|
|
|
|
# Run the page when this file is executed as the main script.
if __name__ == "__main__":
    main()