Spaces:
Running
Running
import streamlit as st | |
import pandas as pd | |
import requests | |
import json | |
import time | |
from typing import Iterator, Dict, Any | |
from apis import stream_ask, run_sql | |
from utils import format_sql | |
def stream_ask_api(api_key: str, project_id: str, question: str, thread_id: str = "",
                   language: str = "English", sample_size: int = 1000) -> Iterator[Dict[str, Any]]:
    """Stream decoded Server-Sent Events (SSE) payloads from the Stream Ask endpoint.

    Calls ``stream_ask`` and yields the JSON payload of every ``data:`` line in
    the response. Non-``data`` lines and payloads that are not valid JSON are
    skipped. Iteration stops at a ``[DONE]`` sentinel or when the response is
    exhausted, and the response is closed in every case (including when the
    consumer stops iterating early).

    Args:
        api_key: API key forwarded to ``stream_ask``.
        project_id: Project ID forwarded to ``stream_ask``.
        question: The natural-language question to ask.
        thread_id: Existing conversation thread ID, or "" for a new thread.
        language: Language of the response.
        sample_size: Sample size forwarded to ``stream_ask``.

    Yields:
        Each decoded JSON payload. If a ``requests`` network error occurs, a
        single ``{"error": "<message>"}`` dict is yielded and the stream ends.

    Raises:
        RuntimeError: If ``stream_ask`` reports an error or returns no response.
    """
    try:
        response, error = stream_ask(api_key, project_id, question, thread_id, language, sample_size)
    except requests.exceptions.RequestException as e:
        yield {"error": str(e)}
        return

    if error or response is None:
        # Raised rather than yielded as an {"error": ...} event, so the caller
        # handles it in its exception handler.
        raise RuntimeError(error or "Stream Ask API returned no response.")

    try:
        for raw_line in response.iter_lines():
            if not raw_line:
                continue  # blank lines only separate SSE events
            line = raw_line.decode("utf-8") if isinstance(raw_line, bytes) else raw_line
            if not line.startswith("data:"):
                continue
            data_str = line[len("data:"):]
            # In the SSE format a single space after "data:" is optional.
            if data_str.startswith(" "):
                data_str = data_str[1:]
            if data_str.strip() == "[DONE]":
                break
            try:
                data = json.loads(data_str)
            except json.JSONDecodeError:
                continue
            yield data
    except requests.exceptions.RequestException as e:
        yield {"error": str(e)}
    finally:
        # Release the streaming connection even when iteration ends early.
        close = getattr(response, "close", None)
        if callable(close):
            close()
def _format_state_message(data: Dict[str, Any]) -> str:
    """Build the progress text shown for a ``state`` event of the stream."""
    message = f"💬 {data.get('state', '')}\n\n"
    sections = (
        ("Rephrased Question", data.get("rephrasedQuestion", "")),
        ("Intent Reasoning", data.get("intentReasoning", "")),
        ("Retrieved Tables", data.get("retrievedTables", "")),
        ("SQL Generation Reasoning", data.get("sqlGenerationReasoning", "")),
    )
    for title, value in sections:
        if isinstance(value, (list, tuple)):
            # NOTE(review): payload types are not visible here; join sequences
            # (e.g. retrievedTables) instead of printing a Python repr.
            value = ", ".join(str(item) for item in value)
        if value:
            message += f"\n{title}: \n{value}\n"
    return message


def _render_sql_section(api_key: str, project_id: str, sql: str, sample_size: int, button_key: str) -> None:
    """Render a generated SQL query and a button that executes it with ``run_sql``.

    Used both for assistant messages redrawn from chat history and for the reply
    that was just streamed, so both show the same expander and button.

    Args:
        api_key: API key forwarded to ``run_sql``.
        project_id: Project ID forwarded to ``run_sql``.
        sql: SQL text returned by the Stream Ask API.
        sample_size: Value forwarded to ``run_sql`` as its last argument.
        button_key: Streamlit widget key for the "Run SQL Query" button.
    """
    with st.expander("🔍 Generated SQL Query", expanded=False):
        st.code(format_sql(sql), language="sql")
    # Clicking reruns the script, and on that rerun the message is drawn from
    # history; label and key are identical at both call sites so Streamlit
    # treats both renders as the same widget.
    if not st.button("🚀 Run SQL Query", key=button_key):
        return
    with st.spinner("Executing SQL query..."):
        sql_result, error = run_sql(api_key, project_id, sql, st.session_state.streaming_thread_id, sample_size)
    if not sql_result:
        st.error(f"Error executing SQL: {error}")
        return
    records = sql_result.get("records", [])
    if records:
        st.success("SQL query executed successfully!")
        st.dataframe(pd.DataFrame(records), use_container_width=True)
    else:
        st.info("Query executed but returned no data.")


def main():
    """Render the "Ask Streaming" demo page.

    Requires ``api_key`` and ``project_id`` in ``st.session_state`` (entered on
    the Home page). Replays the chat kept in ``st.session_state.streaming_messages``,
    streams the answer to a new question via ``stream_ask_api``, stores it, and
    offers a button to run any generated SQL. The conversation thread ID is kept
    in ``st.session_state.streaming_thread_id``.
    """
    st.title("💬 Wren AI Cloud API Demo - Ask Streaming")

    api_key = st.session_state.get("api_key")
    project_id = st.session_state.get("project_id")
    if not api_key or not project_id:
        st.error("Please enter your API Key and Project ID in the sidebar of Home page to get started.")
        return

    st.markdown('Using APIs: [Stream Ask](https://wrenai.readme.io/reference/post_stream-ask-1)')

    # Sidebar for API configuration
    with st.sidebar:
        st.header("🔧 Configuration")
        language = st.text_input(
            "Language",
            "English",
            help="Language of the response. Default is English.",
        )
        sample_size = st.slider(
            "Sample Size",
            min_value=100,
            max_value=10000,
            value=1000,
            step=100,
            help="Number of data points to include in results",
        )

    # Chat history and thread ID persist across reruns in session state.
    if "streaming_messages" not in st.session_state:
        st.session_state.streaming_messages = []
    if "streaming_thread_id" not in st.session_state:
        st.session_state.streaming_thread_id = ""

    # Display chat history
    for message in st.session_state.streaming_messages:
        with st.chat_message(message["role"]):
            st.text(message["content"])
            if message["role"] != "user" and "sql" in message:
                _render_sql_section(
                    api_key, project_id, message["sql"], sample_size,
                    f"run_sql_{message.get('message_id', 'unknown')}",
                )

    if prompt := st.chat_input("Ask a question about your data..."):
        st.session_state.streaming_messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            # Plain text, matching how the history loop redraws user messages.
            st.text(prompt)

        with st.chat_message("assistant"):
            status_container = st.empty()
            content_container = st.empty()
            final_answer = ""
            final_sql = ""
            # Set once an assistant reply is in the history, so a later failure
            # does not append a second (fallback) reply for the same question.
            reply_stored = False
            try:
                for event in stream_ask_api(api_key, project_id, prompt, st.session_state.streaming_thread_id,
                                            sample_size=sample_size, language=language):
                    if not isinstance(event, dict):
                        continue  # only JSON objects carry the fields handled below
                    if "error" in event:
                        st.error(f"Error: {event['error']}")
                        break
                    event_type = event.get("type")
                    if event_type == "state":
                        status_container.info(_format_state_message(event.get("data") or {}))
                    elif event_type == "content_block_delta":
                        delta = event.get("delta") or {}
                        if delta.get("type") == "text_delta":
                            final_answer += delta.get("text", "")
                            content_container.text(final_answer)
                    elif event_type == "message_stop":
                        # The final event may carry a summary (replacing the
                        # streamed text), the SQL, and the thread ID.
                        final_answer = event.get("summary", final_answer)
                        final_sql = event.get("sql", final_sql)
                        if "threadId" in event:
                            st.session_state.streaming_thread_id = event["threadId"]
                        break
                    # Small delay to make streaming visible
                    time.sleep(0.1)

                status_container.empty()

                if final_answer:
                    st.toast("Answer generated successfully!", icon="🎉")
                    message_id = len(st.session_state.streaming_messages)
                    assistant_message = {
                        "role": "assistant",
                        "content": final_answer,
                        "message_id": message_id,
                    }
                    if final_sql:
                        assistant_message["sql"] = final_sql
                    st.session_state.streaming_messages.append(assistant_message)
                    reply_stored = True
                    content_container.text(final_answer)
                    if final_sql:
                        _render_sql_section(api_key, project_id, final_sql, sample_size, f"run_sql_{message_id}")
                else:
                    st.toast("No answer was generated. Please try rephrasing your question.", icon="🤔")
                    fallback = ("I couldn't generate an answer for your question. "
                                "Please try rephrasing it or make sure it's related to your data.")
                    st.session_state.streaming_messages.append({"role": "assistant", "content": fallback})
                    reply_stored = True
                    content_container.text(fallback)
            except Exception as e:
                st.error(f"Error during streaming: {str(e)}")
                if not reply_stored:
                    fallback = "Sorry, I couldn't process your request. Please check your API credentials and try again."
                    st.session_state.streaming_messages.append({"role": "assistant", "content": fallback})
                    content_container.text(fallback)

    # Clear chat button
    if st.sidebar.button("🧹 Clear Chat History"):
        st.session_state.streaming_messages = []
        st.session_state.streaming_thread_id = ""
        st.rerun()
# Render the page when this file is executed as the entry-point script;
# importing the module has no side effects.
if __name__ == "__main__":
    main()