# wrenai-cloud-api-demo: src/pages/1_Ask_Streaming.py
# Last change by cyyeh ("update", commit 99a7fbb)
import streamlit as st
import pandas as pd
import requests
import json
import time
from typing import Iterator, Dict, Any
from apis import stream_ask, run_sql
from utils import format_sql
def stream_ask_api(api_key: str, project_id: str, question: str, thread_id: str = "",
                   language: str = "English", sample_size: int = 1000) -> Iterator[Dict[str, Any]]:
    """Stream ask endpoint with proper SSE handling.

    Opens the stream via ``apis.stream_ask`` and yields each JSON object carried
    in an SSE ``data:`` field. It stops at the ``[DONE]`` sentinel or at the end
    of the stream.

    Args:
        api_key: Wren AI Cloud API key.
        project_id: Project the question is asked against.
        question: Natural-language question.
        thread_id: Conversation thread to continue, or "" for a new one.
        language: Response language, passed through to ``stream_ask``.
        sample_size: Sample size, passed through to ``stream_ask``.

    Yields:
        Each decoded JSON object (dict) from the stream. Payloads that are not
        valid JSON, or that decode to a non-object, are skipped. A ``requests``
        failure while connecting or reading is yielded once as
        ``{"error": <message>}`` and ends the stream.

    Raises:
        RuntimeError: if ``stream_ask`` reports an error or returns no response.
    """
    try:
        response, error = stream_ask(api_key, project_id, question, thread_id, language, sample_size)
        if error or response is None:
            raise RuntimeError(error or "Stream Ask returned no response")
        try:
            for raw_line in response.iter_lines():
                if not raw_line:
                    continue  # blank lines only separate SSE events
                line = raw_line.decode("utf-8", errors="replace") if isinstance(raw_line, bytes) else raw_line
                if not line.startswith("data:"):
                    continue  # ignore event:/id:/retry: fields and ":" comments
                data_str = line[len("data:"):]
                if data_str.startswith(" "):
                    data_str = data_str[1:]  # SSE allows one optional space after the colon
                if data_str.strip() == "[DONE]":
                    break
                try:
                    data = json.loads(data_str)
                except json.JSONDecodeError:
                    continue
                # Callers treat every event as a mapping (`"error" in event`,
                # `event.get(...)`), so only JSON objects are forwarded.
                if isinstance(data, dict):
                    yield data
        finally:
            # Release the streaming connection even if the consumer stops early
            # (closing the generator runs this finally block).
            response.close()
    except requests.exceptions.RequestException as e:
        yield {"error": str(e)}
def _render_sql_result(api_key: str, project_id: str, sql: str, thread_id: str, sample_size: int) -> None:
    """Run *sql* through the Run SQL API and render the records, an empty notice, or the error."""
    with st.spinner("Executing SQL query..."):
        sql_result, error = run_sql(api_key, project_id, sql, thread_id, sample_size)
    if not sql_result:
        st.error(f"Error executing SQL: {error}")
        return
    records = sql_result.get("records", [])
    if records:
        st.success("SQL query executed successfully!")
        # A DataFrame gives a sortable, scrollable table
        st.dataframe(pd.DataFrame(records), use_container_width=True)
    else:
        st.info("Query executed but returned no data.")


def _render_sql_section(api_key: str, project_id: str, sql: str, thread_id: str,
                        sample_size: int, message_id: Any) -> None:
    """Show generated SQL in a collapsed expander plus a button that executes it.

    *message_id* keys the button so each assistant message gets its own widget.
    """
    with st.expander("Generated SQL Query", expanded=False):
        st.code(format_sql(sql), language="sql")
    if st.button("🔄 Run SQL Query", key=f"run_sql_{message_id}"):
        _render_sql_result(api_key, project_id, sql, thread_id, sample_size)


def _format_state_message(state_data: Dict[str, Any]) -> str:
    """Build the status-box text for a ``state`` event from its ``data`` payload."""
    message = f"💬 {state_data.get('state', '')}\n\n"
    sections = (
        ("Rephrased Question", state_data.get("rephrasedQuestion", "")),
        ("Intent Reasoning", state_data.get("intentReasoning", "")),
        ("Retrieved Tables", state_data.get("retrievedTables", "")),
        ("SQL Generation Reasoning", state_data.get("sqlGenerationReasoning", "")),
    )
    for title, value in sections:
        if value:
            message += f"\n{title}: \n{value}\n"
    return message


def _stream_answer(api_key: str, project_id: str, prompt: str, language: str, sample_size: int,
                   status_container: Any, content_container: Any) -> tuple:
    """Consume Stream Ask events for *prompt* and update the placeholders as they arrive.

    Returns:
        ``(answer, sql)``. *answer* is the ``summary`` from ``message_stop`` if
        present, otherwise the concatenated text deltas. *sql* is the ``sql``
        from ``message_stop``. Either is "" when nothing arrived. A ``threadId``
        in ``message_stop`` is saved to ``st.session_state.streaming_thread_id``.
    """
    answer = ""
    sql = ""
    events = stream_ask_api(api_key, project_id, prompt, st.session_state.streaming_thread_id,
                            sample_size=sample_size, language=language)
    for event in events:
        if "error" in event:
            st.error(f"Error: {event['error']}")
            break
        event_type = event.get("type")
        if event_type == "state":
            status_container.info(_format_state_message(event.get("data") or {}))
        elif event_type == "content_block_delta":
            delta = event.get("delta") or {}
            if delta.get("type") == "text_delta":
                answer += delta.get("text", "")
                content_container.text(answer)
        elif event_type == "message_stop":
            answer = event.get("summary", answer)
            sql = event.get("sql", sql)
            if "threadId" in event:
                st.session_state.streaming_thread_id = event["threadId"]
            break
        # Small delay to make streaming visible
        time.sleep(0.1)
    return answer, sql


def _respond_to_prompt(api_key: str, project_id: str, prompt: str, language: str, sample_size: int) -> None:
    """Stream an answer for *prompt* into the current chat message and record it in history."""
    status_container = st.empty()
    content_container = st.empty()
    try:
        answer, sql = _stream_answer(api_key, project_id, prompt, language, sample_size,
                                     status_container, content_container)
        status_container.empty()
        if answer:
            st.toast("Answer generated successfully!", icon="🎉")
            # Index in the history list; unique for the life of the history and used for widget keys
            message_id = len(st.session_state.streaming_messages)
            assistant_message = {
                "role": "assistant",
                "content": answer,
                "message_id": message_id,
            }
            if sql:
                assistant_message["sql"] = sql
            st.session_state.streaming_messages.append(assistant_message)
            content_container.text(answer)
            if sql:
                _render_sql_section(api_key, project_id, sql, st.session_state.streaming_thread_id,
                                    sample_size, message_id)
        else:
            st.toast("No answer was generated. Please try rephrasing your question.", icon="🤔")
            assistant_message = {
                "role": "assistant",
                "content": "I couldn't generate an answer for your question. Please try rephrasing it or make sure it's related to your data."
            }
            st.session_state.streaming_messages.append(assistant_message)
            content_container.text(assistant_message["content"])
    except Exception as e:
        # Page-level boundary: show the failure instead of crashing the whole page
        st.error(f"Error during streaming: {str(e)}")
        assistant_message = {
            "role": "assistant",
            "content": "Sorry, I couldn't process your request. Please check your API credentials and try again."
        }
        st.session_state.streaming_messages.append(assistant_message)
        content_container.text(assistant_message["content"])


def main():
    """Render the "Ask Streaming" chat page.

    Requires ``api_key`` and ``project_id`` in session state; these are set on the Home page.
    Chat history and the conversation thread id are kept in
    ``st.session_state.streaming_messages`` and ``streaming_thread_id``.
    """
    st.title("💬 Wren AI Cloud API Demo - Ask Streaming")
    api_key = st.session_state.get("api_key")
    project_id = st.session_state.get("project_id")
    if not api_key or not project_id:
        st.error("Please enter your API Key and Project ID in the sidebar of Home page to get started.")
        return
    st.markdown('Using APIs: [Stream Ask](https://wrenai.readme.io/reference/post_stream-ask-1)')

    # Sidebar for API configuration
    with st.sidebar:
        st.header("🔧 Configuration")
        language = st.text_input(
            "Language",
            "English",
            help="Language of the response. Default is English."
        )
        sample_size = st.slider(
            "Sample Size",
            min_value=100,
            max_value=10000,
            value=1000,
            step=100,
            help="Number of data points to include in results"
        )

    # Initialize chat history
    if "streaming_messages" not in st.session_state:
        st.session_state.streaming_messages = []
    if "streaming_thread_id" not in st.session_state:
        st.session_state.streaming_thread_id = ""

    # Replay chat history (Streamlit reruns the whole script on each interaction)
    for message in st.session_state.streaming_messages:
        with st.chat_message(message["role"]):
            st.text(message["content"])
            if "sql" in message:
                _render_sql_section(api_key, project_id, message["sql"], st.session_state.streaming_thread_id,
                                    sample_size, message.get("message_id", "unknown"))

    if prompt := st.chat_input("Ask a question about your data..."):
        st.session_state.streaming_messages.append({"role": "user", "content": prompt})
        # Plain text, matching how history replays user messages
        with st.chat_message("user"):
            st.text(prompt)
        with st.chat_message("assistant"):
            _respond_to_prompt(api_key, project_id, prompt, language, sample_size)

    if st.sidebar.button("🧹 Clear Chat History"):
        st.session_state.streaming_messages = []
        st.session_state.streaming_thread_id = ""
        st.rerun()


if __name__ == "__main__":
    main()