|
import streamlit as st |
|
import requests |
|
|
|
|
|
# Page-level Streamlit settings. Called once at import time, before any other
# Streamlit element is rendered.
st.set_page_config(
    page_title="Wren AI Cloud API Demo",
    # NOTE(review): the original icon literal was mis-encoded ("π") and is not
    # a valid emoji; the intended glyph is unknown, so a chart emoji is used.
    page_icon="📊",
    layout="wide",
)
|
|
|
def generate_sql(
    api_key: str,
    project_id: str,
    query: str,
    thread_id: str = "",
    timeout: float = 120.0,
) -> dict:
    """Generate SQL from a natural-language question via the Wren AI Cloud API.

    Args:
        api_key: Wren AI Cloud API key, sent as a Bearer token.
        project_id: Wren AI Cloud project ID (sent as ``projectId``).
        query: Natural-language question (sent as ``question``).
        thread_id: Optional conversation thread ID; sent as ``threadId`` only
            when non-empty.
        timeout: Seconds to wait for the API before giving up. Without a
            timeout ``requests`` can block indefinitely and freeze the app.

    Returns:
        The decoded JSON response object, or ``{}`` on any network, HTTP,
        decoding, or response-shape error (an error toast is shown then).
    """
    base_url = "https://cloud.getwren.ai/api/v1"
    endpoint = f"{base_url}/generate_sql"
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    payload = {
        "projectId": project_id,
        "question": query,
    }
    if thread_id:
        payload["threadId"] = thread_id

    try:
        response = requests.post(endpoint, json=payload, headers=headers, timeout=timeout)
        response.raise_for_status()
        data = response.json()
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError covers a non-JSON body on older requests versions, where
        # the JSON decode error is not a RequestException subclass.
        st.toast(f"Error generating SQL: {e}", icon="🚨")
        return {}

    # Callers use dict methods (.get) on the result, so reject other shapes.
    if not isinstance(data, dict):
        st.toast("Error generating SQL: unexpected response format", icon="🚨")
        return {}
    return data
|
|
|
def generate_chart(
    api_key: str,
    project_id: str,
    question: str,
    sql: str,
    sample_size: int = 1000,
    thread_id: str = "",
    timeout: float = 120.0,
) -> dict:
    """Request a Vega chart spec for a question and its SQL from Wren AI Cloud.

    Args:
        api_key: Wren AI Cloud API key, sent as a Bearer token.
        project_id: Wren AI Cloud project ID (sent as ``projectId``).
        question: The natural-language question (sent as ``question``).
        sql: The SQL query the chart should visualize (sent as ``sql``).
        sample_size: Sent as ``sampleSize``; presumably caps the number of
            data points used for the chart — confirm against the API docs.
        thread_id: Optional conversation thread ID; sent as ``threadId`` only
            when non-empty.
        timeout: Seconds to wait for the API before giving up. Without a
            timeout ``requests`` can block indefinitely and freeze the app.

    Returns:
        The decoded JSON response object, or ``{}`` on any network, HTTP,
        decoding, or response-shape error (an error toast is shown then).
    """
    base_url = "https://cloud.getwren.ai/api/v1"
    endpoint = f"{base_url}/generate_vega_chart"
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    payload = {
        "projectId": project_id,
        "question": question,
        "sql": sql,
        "sampleSize": sample_size,
    }
    if thread_id:
        payload["threadId"] = thread_id

    try:
        response = requests.post(endpoint, json=payload, headers=headers, timeout=timeout)
        response.raise_for_status()
        data = response.json()
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError covers a non-JSON body on older requests versions, where
        # the JSON decode error is not a RequestException subclass.
        st.toast(f"Error generating chart: {e}", icon="🚨")
        return {}

    # Callers use dict methods (.get) on the result, so reject other shapes.
    if not isinstance(data, dict):
        st.toast("Error generating chart: unexpected response format", icon="🚨")
        return {}
    return data
|
|
|
# NOTE(review): emoji literals in this UI code were mis-encoded in the source
# (e.g. "π¨"). Unambiguous ones were restored (🚨, 🔧, ⚠️, 🗑️); the title,
# expander-label and success-icon emoji are best guesses — adjust as desired.


def _render_sql(sql: str) -> None:
    """Show a generated SQL query inside a collapsed expander."""
    with st.expander("📝 Generated SQL Query", expanded=False):
        st.code(sql, language="sql")


def _render_chart(vega_spec: dict) -> None:
    """Show a chart spec (collapsed) and render it with ``st.vega_lite_chart``.

    The spec comes from the API and may be malformed, so rendering errors are
    reported as a toast instead of crashing the script run.
    """
    try:
        with st.expander("📋 Chart Specification", expanded=False):
            st.json(vega_spec)
        st.vega_lite_chart(vega_spec)
    except Exception as e:  # UI boundary: surface any rendering failure.
        st.toast(f"Error rendering chart: {e}", icon="🚨")


def _render_sidebar():
    """Render the configuration sidebar; return (api_key, project_id, sample_size)."""
    with st.sidebar:
        st.header("🔧 Configuration")
        api_key = st.text_input(
            "API Key",
            type="password",
            placeholder="sk-your-api-key-here",
            help="Enter your Wren AI Cloud API key",
        )
        project_id = st.text_input(
            "Project ID",
            placeholder="1234",
            help="Enter your Wren AI Cloud project ID",
        )
        sample_size = st.slider(
            "Chart Sample Size",
            min_value=100,
            max_value=10000,
            value=1000,
            step=100,
            help="Number of data points to include in charts",
        )
    return api_key, project_id, sample_size


def _render_history() -> None:
    """Replay the stored transcript from ``st.session_state.messages``."""
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.write(message["content"])
            # Only assistant messages carry SQL / chart payloads.
            if message["role"] != "user":
                if "sql" in message:
                    _render_sql(message["sql"])
                if "vega_spec" in message:
                    _render_chart(message["vega_spec"])


def _add_assistant_message(content: str, **extra) -> None:
    """Record an assistant message in the history AND show its text right now.

    Writing immediately matters: otherwise failure messages would only appear
    after the next rerun replays the history.
    """
    st.session_state.messages.append({"role": "assistant", "content": content, **extra})
    st.write(content)


def _handle_prompt(prompt: str, api_key: str, project_id: str, sample_size: int) -> None:
    """Process one user question: generate SQL, then a chart, and render both."""
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.write(prompt)

    with st.chat_message("assistant"):
        with st.spinner("Generating SQL query..."):
            sql_response = generate_sql(api_key, project_id, prompt, st.session_state.thread_id)

        if not sql_response:
            _add_assistant_message(
                "Sorry, I couldn't process your request. "
                "Please check your API credentials and try again."
            )
            return

        # Keep the current thread if the response omits one, so follow-up
        # questions don't silently lose their conversation context.
        st.session_state.thread_id = sql_response.get("threadId") or st.session_state.thread_id

        sql_query = sql_response.get("sql", "")
        if not sql_query:
            st.toast("No SQL query was generated. Please try rephrasing your question.", icon="🚨")
            _add_assistant_message(
                "I couldn't generate a SQL query for your question. "
                "Please try rephrasing it or make sure it's related to your data."
            )
            return

        st.toast("SQL query generated successfully!", icon="🎉")
        _add_assistant_message(
            f"I've generated a SQL query for your question: '{prompt}'", sql=sql_query
        )
        _render_sql(sql_query)

        with st.spinner("Generating chart..."):
            chart_response = generate_chart(
                api_key, project_id, prompt, sql_query, sample_size, st.session_state.thread_id
            )

        # An empty response and a response without a spec are the same failure.
        vega_spec = chart_response.get("vegaSpec", {}) if chart_response else {}
        if not vega_spec:
            st.toast("Failed to generate chart. Please check your query and try again.", icon="🚨")
            return

        st.toast("Chart generated successfully!", icon="🎉")
        _add_assistant_message(
            f"I've generated a Chart for your question: '{prompt}'", vega_spec=vega_spec
        )
        _render_chart(vega_spec)


def main():
    """Run the Wren AI Cloud demo page: sidebar config, chat history, new questions."""
    st.title("📊 Wren AI Cloud API Demo")
    st.markdown("Ask questions about your data and get both SQL queries and beautiful charts!")

    api_key, project_id, sample_size = _render_sidebar()

    if not api_key or not project_id:
        st.warning("⚠️ Please enter your API Key and Project ID in the sidebar to get started.")
        st.info("""
        **How to get started:**
        1. Enter your Wren AI API Key in the sidebar
        2. Enter your Project ID
        3. Ask questions about your data in natural language
        4. Get SQL queries and interactive charts automatically!
        """)
        return

    # Session state persists across Streamlit reruns; initialize once.
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "thread_id" not in st.session_state:
        st.session_state.thread_id = ""

    _render_history()

    if prompt := st.chat_input("Ask a question about your data..."):
        _handle_prompt(prompt, api_key, project_id, sample_size)

    if st.sidebar.button("🗑️ Clear Chat History"):
        st.session_state.messages = []
        st.session_state.thread_id = ""
        st.rerun()
|
|
|
# Build the app page only when this file is executed directly, not when it is
# imported as a module.
if __name__ == "__main__":
    main()