# Provenance (page header captured when this file was copied from a web view):
# author cyyeh, commit b038dbc ("refactor"), 6.91 kB.
import streamlit as st
from apis import generate_sql, generate_chart
# Page configuration.
# Kept at module level so it runs before any other Streamlit call: older
# Streamlit releases reject st.set_page_config once other elements have been
# emitted. page_icon must be a real emoji (the literal previously held a
# mis-decoded byte sequence rather than the 📊 character).
st.set_page_config(
    page_title="Wren AI Cloud API Demo",
    page_icon="📊",
    layout="wide",
)
def _render_sidebar():
    """Draw the sidebar configuration widgets and return their current values.

    Returns:
        tuple: ``(api_key, project_id, sample_size)``. The two credentials are
        stripped of surrounding whitespace, so whitespace-only input counts
        as missing.
    """
    with st.sidebar:
        st.header("🔧 Configuration")
        api_key = st.text_input(
            "API Key",
            type="password",
            placeholder="sk-your-api-key-here",
            help="Enter your Wren AI Cloud API key",
        )
        project_id = st.text_input(
            "Project ID",
            placeholder="1234",
            help="Enter your Wren AI Cloud project ID",
        )
        # Forwarded to generate_chart(); the help text describes it as the
        # number of data points to include in charts.
        sample_size = st.slider(
            "Chart Sample Size",
            min_value=100,
            max_value=10000,
            value=1000,
            step=100,
            help="Number of data points to include in charts",
        )
    return api_key.strip(), project_id.strip(), sample_size


def _render_getting_started():
    """Tell the user that credentials are required before chatting."""
    st.warning("⚠️ Please enter your API Key and Project ID in the sidebar to get started.")
    st.info(
        "**How to get started:**\n"
        "1. Enter your Wren AI Cloud API Key in the sidebar\n"
        "2. Enter your Project ID\n"
        "3. Ask questions about your data in natural language\n"
        "4. Get SQL queries and interactive charts automatically!"
    )


def _init_session_state():
    """Create the chat transcript and conversation thread id on first run."""
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "thread_id" not in st.session_state:
        # Empty string means "no thread yet"; it is passed to the API as-is.
        st.session_state.thread_id = ""


def _render_sql(sql_query):
    """Show a SQL query inside a collapsed expander."""
    with st.expander("📝 Generated SQL Query", expanded=False):
        st.code(sql_query, language="sql")


def _render_chart(vega_spec):
    """Show a Vega-Lite spec (collapsed JSON) followed by the rendered chart.

    Rendering errors are reported as a toast instead of propagating, so a
    single bad spec does not break the rest of the page.
    """
    try:
        with st.expander("📊 Chart Specification", expanded=False):
            st.json(vega_spec)
        st.vega_lite_chart(vega_spec)
    except Exception as e:
        st.toast(f"Error rendering chart: {e}", icon="🚨")


def _render_message(message):
    """Render one transcript entry: its text plus any SQL or chart payload.

    Only assistant entries carry "sql" / "vega_spec" keys; user entries hold
    just "role" and "content".
    """
    st.write(message["content"])
    if "sql" in message:
        _render_sql(message["sql"])
    if "vega_spec" in message:
        _render_chart(message["vega_spec"])


def _render_history():
    """Replay every stored message inside its role's chat bubble."""
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            _render_message(message)


def _post_assistant_message(**fields):
    """Append an assistant message to the transcript and display it right away.

    Must be called inside the active ``st.chat_message("assistant")`` block
    so the output lands in that bubble.
    """
    message = {"role": "assistant", **fields}
    st.session_state.messages.append(message)
    _render_message(message)


def _generate_and_show_chart(prompt, api_key, project_id, sql_query, sample_size):
    """Request a chart for *sql_query* and display it; failures only toast."""
    with st.spinner("Generating chart..."):
        chart_response, error = generate_chart(
            api_key, project_id, prompt, sql_query, sample_size, st.session_state.thread_id
        )
    if not chart_response:
        st.toast(
            f"Failed to generate chart. Please check your query and try again. Details: {error}",
            icon="🚨",
        )
        return
    vega_spec = chart_response.get("vegaSpec", {})
    if not vega_spec:
        st.toast("Failed to generate chart. Please check your query and try again.", icon="🚨")
        return
    st.toast("Chart generated successfully!", icon="🎉")
    _post_assistant_message(
        content=f"I've generated a Chart for your question: '{prompt}'",
        vega_spec=vega_spec,
    )


def _handle_prompt(prompt, api_key, project_id, sample_size):
    """Record the user's question, then generate and display SQL and a chart.

    Every outcome, including both failure paths, is appended to the
    transcript and shown immediately. Failure messages previously appeared
    only on the next rerun.
    """
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.write(prompt)

    with st.chat_message("assistant"):
        with st.spinner("Generating SQL query..."):
            sql_response, error = generate_sql(
                api_key, project_id, prompt, st.session_state.thread_id
            )
        if not sql_response:
            st.toast(f"Error generating SQL: {error}", icon="🚨")
            _post_assistant_message(
                content="Sorry, I couldn't process your request. Please check your API credentials and try again."
            )
            return

        sql_query = sql_response.get("sql", "")
        # Remember the thread id returned by the API so follow-up questions are
        # sent with it (empty string if the response carried none).
        st.session_state.thread_id = sql_response.get("threadId", "")
        if not sql_query:
            st.toast("No SQL query was generated. Please try rephrasing your question.", icon="🚨")
            _post_assistant_message(
                content="I couldn't generate a SQL query for your question. Please try rephrasing it or make sure it's related to your data."
            )
            return

        st.toast("SQL query generated successfully!", icon="🎉")
        _post_assistant_message(
            content=f"I've generated a SQL query for your question: '{prompt}'",
            sql=sql_query,
        )
        _generate_and_show_chart(prompt, api_key, project_id, sql_query, sample_size)


def main():
    """Run one Streamlit execution of the Wren AI Cloud chat demo.

    Streamlit re-runs the script (and therefore this function) on every
    interaction. Each run draws the sidebar and stops with setup
    instructions until both credentials are present. It then handles the
    Clear Chat button, replays the stored transcript, and processes at most
    one newly submitted question.
    """
    st.title("📊 Wren AI Cloud API Demo")
    st.markdown("Ask questions about your data and get both SQL queries and beautiful charts!")

    api_key, project_id, sample_size = _render_sidebar()

    if not api_key or not project_id:
        _render_getting_started()
        return

    _init_session_state()

    # Checked before the history is replayed, so a click does not first
    # re-render the transcript that is about to be discarded. The button
    # still appears in the sidebar below the configuration widgets.
    if st.sidebar.button("🗑️ Clear Chat History"):
        st.session_state.messages = []
        st.session_state.thread_id = ""
        st.rerun()

    _render_history()

    if prompt := st.chat_input("Ask a question about your data..."):
        _handle_prompt(prompt, api_key, project_id, sample_size)
if __name__ == "__main__":
    # Entry point when this file is executed as the main script; importing
    # the module does not launch the app.
    main()