Spaces:
Running
Running
add new page and refine sql query ui
Browse files- poetry.lock +16 -1
- pyproject.toml +1 -0
- src/pages/1_Chart_Generation.py +3 -2
- src/pages/2_Query_And_Answer.py +159 -0
- src/utils.py +9 -0
poetry.lock
CHANGED
@@ -974,6 +974,21 @@ files = [
|
|
974 |
{file = "smmap-5.0.2.tar.gz", hash = "sha256:26ea65a03958fa0c8a1c7e8c7a58fdc77221b8910f6be2131affade476898ad5"},
|
975 |
]
|
976 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
977 |
[[package]]
|
978 |
name = "streamlit"
|
979 |
version = "1.46.1"
|
@@ -1139,4 +1154,4 @@ watchmedo = ["PyYAML (>=3.10)"]
|
|
1139 |
[metadata]
|
1140 |
lock-version = "2.0"
|
1141 |
python-versions = ">=3.12,<3.13"
|
1142 |
-
content-hash = "
|
|
|
974 |
{file = "smmap-5.0.2.tar.gz", hash = "sha256:26ea65a03958fa0c8a1c7e8c7a58fdc77221b8910f6be2131affade476898ad5"},
|
975 |
]
|
976 |
|
977 |
+
[[package]]
|
978 |
+
name = "sqlparse"
|
979 |
+
version = "0.5.3"
|
980 |
+
description = "A non-validating SQL parser."
|
981 |
+
optional = false
|
982 |
+
python-versions = ">=3.8"
|
983 |
+
files = [
|
984 |
+
{file = "sqlparse-0.5.3-py3-none-any.whl", hash = "sha256:cf2196ed3418f3ba5de6af7e82c694a9fbdbfecccdfc72e281548517081f16ca"},
|
985 |
+
{file = "sqlparse-0.5.3.tar.gz", hash = "sha256:09f67787f56a0b16ecdbde1bfc7f5d9c3371ca683cfeaa8e6ff60b4807ec9272"},
|
986 |
+
]
|
987 |
+
|
988 |
+
[package.extras]
|
989 |
+
dev = ["build", "hatch"]
|
990 |
+
doc = ["sphinx"]
|
991 |
+
|
992 |
[[package]]
|
993 |
name = "streamlit"
|
994 |
version = "1.46.1"
|
|
|
1154 |
[metadata]
|
1155 |
lock-version = "2.0"
|
1156 |
python-versions = ">=3.12,<3.13"
|
1157 |
+
content-hash = "67b448796799eb25e83725fc27c23ff400273d48c767151c6462d8e1545052fe"
|
pyproject.toml
CHANGED
@@ -12,6 +12,7 @@ python = ">=3.12,<3.13"
|
|
12 |
streamlit = "^1.46.1"
|
13 |
requests = "^2.32.4"
|
14 |
watchdog = "^6.0.0"
|
|
|
15 |
|
16 |
[tool.poetry.group.dev.dependencies]
|
17 |
python-dotenv = "^1.1.1"
|
|
|
12 |
streamlit = "^1.46.1"
|
13 |
requests = "^2.32.4"
|
14 |
watchdog = "^6.0.0"
|
15 |
+
sqlparse = "^0.5.3"
|
16 |
|
17 |
[tool.poetry.group.dev.dependencies]
|
18 |
python-dotenv = "^1.1.1"
|
src/pages/1_Chart_Generation.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import streamlit as st
|
2 |
|
3 |
from apis import generate_sql, generate_chart
|
|
|
4 |
|
5 |
|
6 |
def main():
|
@@ -46,7 +47,7 @@ def main():
|
|
46 |
st.write(message["content"])
|
47 |
if "sql" in message:
|
48 |
with st.expander("π Generated SQL Query", expanded=False):
|
49 |
-
st.code(message["sql"], language="sql")
|
50 |
if "vega_spec" in message:
|
51 |
try:
|
52 |
with st.expander("π Chart Specification", expanded=False):
|
@@ -87,7 +88,7 @@ def main():
|
|
87 |
|
88 |
# Display SQL query
|
89 |
with st.expander("π Generated SQL Query", expanded=False):
|
90 |
-
st.code(sql_query, language="sql")
|
91 |
|
92 |
# Generate chart
|
93 |
with st.spinner("Generating chart..."):
|
|
|
1 |
import streamlit as st
|
2 |
|
3 |
from apis import generate_sql, generate_chart
|
4 |
+
from utils import format_sql
|
5 |
|
6 |
|
7 |
def main():
|
|
|
47 |
st.write(message["content"])
|
48 |
if "sql" in message:
|
49 |
with st.expander("π Generated SQL Query", expanded=False):
|
50 |
+
st.code(format_sql(message["sql"]), language="sql")
|
51 |
if "vega_spec" in message:
|
52 |
try:
|
53 |
with st.expander("π Chart Specification", expanded=False):
|
|
|
88 |
|
89 |
# Display SQL query
|
90 |
with st.expander("π Generated SQL Query", expanded=False):
|
91 |
+
st.code(format_sql(sql_query), language="sql")
|
92 |
|
93 |
# Generate chart
|
94 |
with st.spinner("Generating chart..."):
|
src/pages/2_Query_And_Answer.py
ADDED
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import pandas as pd
|
3 |
+
|
4 |
+
from apis import ask, run_sql
|
5 |
+
from utils import format_sql
|
6 |
+
|
7 |
+
|
8 |
+
def main():
|
9 |
+
st.title("π¬ Wren AI Cloud API Demo - Query and Answer")
|
10 |
+
|
11 |
+
if "api_key" not in st.session_state or "project_id" not in st.session_state:
|
12 |
+
st.error("Please enter your API Key and Project ID in the sidebar of Home page to get started.")
|
13 |
+
return
|
14 |
+
if not st.session_state.api_key or not st.session_state.project_id:
|
15 |
+
st.error("Please enter your API Key and Project ID in the sidebar of Home page to get started.")
|
16 |
+
return
|
17 |
+
|
18 |
+
api_key = st.session_state.api_key
|
19 |
+
project_id = st.session_state.project_id
|
20 |
+
|
21 |
+
st.markdown('Using APIs: [Ask](https://wrenai.readme.io/reference/post_ask-1), [Run SQL](https://wrenai.readme.io/reference/cloud_post_run-sql)')
|
22 |
+
|
23 |
+
# Sidebar for API configuration
|
24 |
+
with st.sidebar:
|
25 |
+
st.header("π§ Configuration")
|
26 |
+
# Sample size configuration
|
27 |
+
sample_size = st.slider(
|
28 |
+
"Sample Size",
|
29 |
+
min_value=100,
|
30 |
+
max_value=10000,
|
31 |
+
value=1000,
|
32 |
+
step=100,
|
33 |
+
help="Number of data points to include in results"
|
34 |
+
)
|
35 |
+
|
36 |
+
# Initialize chat history
|
37 |
+
if "qa_messages" not in st.session_state:
|
38 |
+
st.session_state.qa_messages = []
|
39 |
+
if "qa_thread_id" not in st.session_state:
|
40 |
+
st.session_state.qa_thread_id = ""
|
41 |
+
|
42 |
+
# Display chat history
|
43 |
+
for message in st.session_state.qa_messages:
|
44 |
+
with st.chat_message(message["role"]):
|
45 |
+
if message["role"] == "user":
|
46 |
+
st.write(message["content"])
|
47 |
+
else:
|
48 |
+
st.write(message["content"])
|
49 |
+
if "sql" in message:
|
50 |
+
with st.expander("π Generated SQL Query", expanded=False):
|
51 |
+
st.code(format_sql(message["sql"]), language="sql")
|
52 |
+
|
53 |
+
# Add button to run SQL
|
54 |
+
if st.button("π Run SQL Query", key=f"run_sql_{message.get('message_id', 'unknown')}"):
|
55 |
+
with st.spinner("Executing SQL query..."):
|
56 |
+
sql_result, error = run_sql(api_key, project_id, message["sql"], st.session_state.qa_thread_id, sample_size)
|
57 |
+
|
58 |
+
if sql_result:
|
59 |
+
data = sql_result.get("records", [])
|
60 |
+
if data:
|
61 |
+
# Convert to DataFrame for better display
|
62 |
+
df = pd.DataFrame(data)
|
63 |
+
st.success("SQL query executed successfully!")
|
64 |
+
st.dataframe(df, use_container_width=True)
|
65 |
+
else:
|
66 |
+
st.info("Query executed but returned no data.")
|
67 |
+
else:
|
68 |
+
st.error(f"Error executing SQL: {error}")
|
69 |
+
|
70 |
+
if "sql_results" in message:
|
71 |
+
st.subheader("π Query Results")
|
72 |
+
if message["sql_results"]:
|
73 |
+
st.dataframe(message["sql_results"], use_container_width=True)
|
74 |
+
else:
|
75 |
+
st.info("No results returned from the query.")
|
76 |
+
|
77 |
+
# Chat input
|
78 |
+
if prompt := st.chat_input("Ask a question about your data..."):
|
79 |
+
# Add user message to chat history
|
80 |
+
st.session_state.qa_messages.append({"role": "user", "content": prompt})
|
81 |
+
|
82 |
+
# Display user message
|
83 |
+
with st.chat_message("user"):
|
84 |
+
st.write(prompt)
|
85 |
+
|
86 |
+
# Generate response using ask API
|
87 |
+
with st.chat_message("assistant"):
|
88 |
+
with st.spinner("Generating answer..."):
|
89 |
+
ask_response, error = ask(api_key, project_id, prompt, st.session_state.qa_thread_id, sample_size=sample_size)
|
90 |
+
|
91 |
+
if ask_response:
|
92 |
+
answer = ask_response.get("summary", "")
|
93 |
+
sql_query = ask_response.get("sql", "")
|
94 |
+
st.session_state.qa_thread_id = ask_response.get("threadId", "")
|
95 |
+
|
96 |
+
if answer:
|
97 |
+
st.toast("Answer generated successfully!", icon="π")
|
98 |
+
|
99 |
+
# Create unique message ID
|
100 |
+
message_id = len(st.session_state.qa_messages)
|
101 |
+
|
102 |
+
# Store the response
|
103 |
+
assistant_message = {
|
104 |
+
"role": "assistant",
|
105 |
+
"content": answer,
|
106 |
+
"message_id": message_id
|
107 |
+
}
|
108 |
+
|
109 |
+
if sql_query:
|
110 |
+
assistant_message["sql"] = sql_query
|
111 |
+
|
112 |
+
st.session_state.qa_messages.append(assistant_message)
|
113 |
+
st.write(answer)
|
114 |
+
|
115 |
+
# Display SQL query if available
|
116 |
+
if sql_query:
|
117 |
+
with st.expander("π Generated SQL Query", expanded=False):
|
118 |
+
st.code(format_sql(sql_query), language="sql")
|
119 |
+
|
120 |
+
# Add button to run SQL
|
121 |
+
if st.button("π Run SQL Query", key=f"run_sql_{message_id}"):
|
122 |
+
with st.spinner("Executing SQL query..."):
|
123 |
+
sql_result, error = run_sql(api_key, project_id, sql_query, st.session_state.qa_thread_id, sample_size)
|
124 |
+
|
125 |
+
if sql_result:
|
126 |
+
data = sql_result.get("records", [])
|
127 |
+
if data:
|
128 |
+
# Convert to DataFrame for better display
|
129 |
+
df = pd.DataFrame(data)
|
130 |
+
st.success("SQL query executed successfully!")
|
131 |
+
st.dataframe(df, use_container_width=True)
|
132 |
+
else:
|
133 |
+
st.info("Query executed but returned no data.")
|
134 |
+
else:
|
135 |
+
st.error(f"Error executing SQL: {error}")
|
136 |
+
else:
|
137 |
+
st.toast("No answer was generated. Please try rephrasing your question.", icon="π€")
|
138 |
+
assistant_message = {
|
139 |
+
"role": "assistant",
|
140 |
+
"content": "I couldn't generate an answer for your question. Please try rephrasing it or make sure it's related to your data."
|
141 |
+
}
|
142 |
+
st.session_state.qa_messages.append(assistant_message)
|
143 |
+
else:
|
144 |
+
st.toast(f"Error generating answer: {error}", icon="π€")
|
145 |
+
assistant_message = {
|
146 |
+
"role": "assistant",
|
147 |
+
"content": "Sorry, I couldn't process your request. Please check your API credentials and try again."
|
148 |
+
}
|
149 |
+
st.session_state.qa_messages.append(assistant_message)
|
150 |
+
|
151 |
+
# Clear chat button
|
152 |
+
if st.sidebar.button("π§Ή Clear Chat History"):
|
153 |
+
st.session_state.qa_messages = []
|
154 |
+
st.session_state.qa_thread_id = ""
|
155 |
+
st.rerun()
|
156 |
+
|
157 |
+
|
158 |
+
if __name__ == "__main__":
|
159 |
+
main()
|
src/utils.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sqlparse
|
2 |
+
|
3 |
+
|
4 |
+
def format_sql(sql: str) -> str:
|
5 |
+
return sqlparse.format(
|
6 |
+
sql,
|
7 |
+
reindent=True,
|
8 |
+
keyword_case="upper",
|
9 |
+
)
|