import streamlit as st
import pandas as pd
import requests
import json
import time
from typing import Iterator, Dict, Any

from apis import stream_ask, run_sql
from utils import format_sql


def stream_ask_api(api_key: str, project_id: str, question: str, thread_id: str = "",
                   language: str = "English", sample_size: int = 1000) -> Iterator[Dict[str, Any]]:
    """Stream ask endpoint with proper SSE handling.

    Calls ``apis.stream_ask`` and reads the returned response line by line as a
    Server-Sent Events stream. Each ``data:`` payload that decodes to a JSON
    object is yielded. Iteration stops at a ``[DONE]`` sentinel or at the end
    of the stream.

    Args:
        api_key: Wren AI Cloud API key.
        project_id: Wren AI Cloud project ID.
        question: Natural-language question to ask.
        thread_id: Thread ID to continue, or "" to start a new thread.
        language: Language of the response.
        sample_size: Number of data points to include in results.

    Yields:
        Decoded event dicts. A ``requests`` transport failure is reported as a
        single ``{"error": "<message>"}`` event instead of being raised.

    Raises:
        RuntimeError: If ``stream_ask`` reports an error or returns no response.
    """
    try:
        response, error = stream_ask(api_key, project_id, question, thread_id, language, sample_size)
        if error:
            raise RuntimeError(error)
        if response is None:
            raise RuntimeError("Stream ask API returned no response.")

        try:
            for raw_line in response.iter_lines():
                if not raw_line:
                    continue  # blank lines only separate SSE events
                line_str = raw_line.decode("utf-8") if isinstance(raw_line, bytes) else raw_line
                if not line_str.startswith("data:"):
                    continue  # ignore other SSE fields (event:, id:, retry:) and comments

                # Per the SSE format, a single space after "data:" is optional.
                data_str = line_str[len("data:"):]
                if data_str.startswith(" "):
                    data_str = data_str[1:]
                if data_str.strip() == "[DONE]":
                    break

                try:
                    data = json.loads(data_str)
                except json.JSONDecodeError:
                    continue  # skip malformed payloads rather than aborting the stream
                # Consumers call .get() on events, so only JSON objects are forwarded.
                if isinstance(data, dict):
                    yield data
        finally:
            # Release the streaming connection even when the consumer stops early
            # (e.g. breaks on "message_stop"), which closes this generator.
            close = getattr(response, "close", None)
            if callable(close):
                close()
    except requests.exceptions.RequestException as e:
        yield {"error": str(e)}


def _format_state_message(data: Dict[str, Any]) -> str:
    """Build the status text shown for the ``data`` payload of a ``state`` event."""
    message = f"💬 {data.get('state', '')}\n\n"
    # Optional detail sections; each is appended only when the event carries it.
    sections = (
        ("Rephrased Question", data.get("rephrasedQuestion", "")),
        ("Intent Reasoning", data.get("intentReasoning", "")),
        ("Retrieved Tables", data.get("retrievedTables", "")),
        ("SQL Generation Reasoning", data.get("sqlGenerationReasoning", "")),
    )
    for label, value in sections:
        if value:
            message += f"\n{label}: \n{value}\n"
    return message


def _render_sql_section(api_key: str, project_id: str, sql: str, button_key: str, sample_size: int) -> None:
    """Show generated SQL in an expander plus a button that executes it via ``run_sql``.

    ``button_key`` must be unique per message so Streamlit can tell the buttons
    apart across reruns.
    """
    with st.expander("📝 Generated SQL Query", expanded=False):
        st.code(format_sql(sql), language="sql")

    if st.button("🔄 Run SQL Query", key=button_key):
        with st.spinner("Executing SQL query..."):
            sql_result, error = run_sql(api_key, project_id, sql, st.session_state.streaming_thread_id, sample_size)
        if sql_result:
            records = sql_result.get("records", [])
            if records:
                # Convert to DataFrame for better display
                st.success("SQL query executed successfully!")
                st.dataframe(pd.DataFrame(records), use_container_width=True)
            else:
                st.info("Query executed but returned no data.")
        else:
            st.error(f"Error executing SQL: {error}")


def _stream_answer(api_key: str, project_id: str, prompt: str, language: str, sample_size: int) -> None:
    """Stream an answer to ``prompt`` into an assistant chat bubble and record it in history."""
    with st.chat_message("assistant"):
        # Placeholders that are overwritten in place as events arrive.
        status_container = st.empty()
        content_container = st.empty()

        final_answer = ""
        final_sql = ""

        try:
            for event in stream_ask_api(api_key, project_id, prompt, st.session_state.streaming_thread_id,
                                        sample_size=sample_size, language=language):
                if "error" in event:
                    st.error(f"Error: {event['error']}")
                    break

                event_type = event.get("type")
                if event_type == "state":
                    status_container.info(_format_state_message(event.get("data") or {}))
                elif event_type == "content_block_delta":
                    delta = event.get("delta") or {}
                    if delta.get("type") == "text_delta":
                        final_answer += delta.get("text", "")
                        content_container.text(final_answer)
                elif event_type == "message_stop":
                    # The final event carries the authoritative summary, SQL and thread ID.
                    if "summary" in event:
                        final_answer = event["summary"]
                    if "sql" in event:
                        final_sql = event["sql"]
                    if "threadId" in event:
                        st.session_state.streaming_thread_id = event["threadId"]
                    break

                # Small delay to make streaming visible
                time.sleep(0.1)

            status_container.empty()

            if final_answer:
                st.toast("Answer generated successfully!", icon="🎉")

                # The message's index in history doubles as its unique ID
                # (used for the Run SQL button key).
                message_id = len(st.session_state.streaming_messages)
                assistant_message = {
                    "role": "assistant",
                    "content": final_answer,
                    "message_id": message_id,
                }
                if final_sql:
                    assistant_message["sql"] = final_sql
                st.session_state.streaming_messages.append(assistant_message)
                content_container.text(final_answer)

                if final_sql:
                    _render_sql_section(api_key, project_id, final_sql, f"run_sql_{message_id}", sample_size)
            else:
                st.toast("No answer was generated. Please try rephrasing your question.", icon="🤔")
                assistant_message = {
                    "role": "assistant",
                    "content": "I couldn't generate an answer for your question. Please try rephrasing it or make sure it's related to your data.",
                }
                st.session_state.streaming_messages.append(assistant_message)
                content_container.text(assistant_message["content"])
        except Exception as e:
            # UI boundary: surface any failure to the user instead of crashing the page.
            st.error(f"Error during streaming: {e}")
            assistant_message = {
                "role": "assistant",
                "content": "Sorry, I couldn't process your request. Please check your API credentials and try again.",
            }
            st.session_state.streaming_messages.append(assistant_message)
            content_container.text(assistant_message["content"])


def main():
    """Render the streaming Ask chat page."""
    st.title("💬 Wren AI Cloud API Demo - Ask Streaming")

    api_key = st.session_state.get("api_key")
    project_id = st.session_state.get("project_id")
    if not api_key or not project_id:
        st.error("Please enter your API Key and Project ID in the sidebar of Home page to get started.")
        return

    st.markdown('Using APIs: [Stream Ask](https://wrenai.readme.io/reference/post_stream-ask-1)')

    # Sidebar for API configuration
    with st.sidebar:
        st.header("🔧 Configuration")
        language = st.text_input(
            "Language",
            "English",
            help="Language of the response. Default is English."
        )
        sample_size = st.slider(
            "Sample Size",
            min_value=100,
            max_value=10000,
            value=1000,
            step=100,
            help="Number of data points to include in results"
        )

    # Initialize chat history
    if "streaming_messages" not in st.session_state:
        st.session_state.streaming_messages = []
    if "streaming_thread_id" not in st.session_state:
        st.session_state.streaming_thread_id = ""

    # Display chat history
    for index, message in enumerate(st.session_state.streaming_messages):
        with st.chat_message(message["role"]):
            st.text(message["content"])
            if message["role"] != "user" and "sql" in message:
                # Fall back to the list index so button keys never collide.
                button_key = f"run_sql_{message.get('message_id', index)}"
                _render_sql_section(api_key, project_id, message["sql"], button_key, sample_size)

    # Chat input
    if prompt := st.chat_input("Ask a question about your data..."):
        st.session_state.streaming_messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            # Same renderer as the history loop so the message looks identical after a rerun.
            st.text(prompt)

        _stream_answer(api_key, project_id, prompt, language, sample_size)

    # Clear chat button
    if st.sidebar.button("🧹 Clear Chat History"):
        st.session_state.streaming_messages = []
        st.session_state.streaming_thread_id = ""
        st.rerun()


if __name__ == "__main__":
    main()