cyyeh commited on
Commit
a1951e5
Β·
1 Parent(s): 2d0b8a5

add new page and refine sql query ui

Browse files
poetry.lock CHANGED
@@ -974,6 +974,21 @@ files = [
974
  {file = "smmap-5.0.2.tar.gz", hash = "sha256:26ea65a03958fa0c8a1c7e8c7a58fdc77221b8910f6be2131affade476898ad5"},
975
  ]
976
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
977
  [[package]]
978
  name = "streamlit"
979
  version = "1.46.1"
@@ -1139,4 +1154,4 @@ watchmedo = ["PyYAML (>=3.10)"]
1139
  [metadata]
1140
  lock-version = "2.0"
1141
  python-versions = ">=3.12,<3.13"
1142
- content-hash = "d3018cb9ea02785fe38a30f0f0e1196ad02058106619e7ba3a2a11c4a78753a2"
 
974
  {file = "smmap-5.0.2.tar.gz", hash = "sha256:26ea65a03958fa0c8a1c7e8c7a58fdc77221b8910f6be2131affade476898ad5"},
975
  ]
976
 
977
+ [[package]]
978
+ name = "sqlparse"
979
+ version = "0.5.3"
980
+ description = "A non-validating SQL parser."
981
+ optional = false
982
+ python-versions = ">=3.8"
983
+ files = [
984
+ {file = "sqlparse-0.5.3-py3-none-any.whl", hash = "sha256:cf2196ed3418f3ba5de6af7e82c694a9fbdbfecccdfc72e281548517081f16ca"},
985
+ {file = "sqlparse-0.5.3.tar.gz", hash = "sha256:09f67787f56a0b16ecdbde1bfc7f5d9c3371ca683cfeaa8e6ff60b4807ec9272"},
986
+ ]
987
+
988
+ [package.extras]
989
+ dev = ["build", "hatch"]
990
+ doc = ["sphinx"]
991
+
992
  [[package]]
993
  name = "streamlit"
994
  version = "1.46.1"
 
1154
  [metadata]
1155
  lock-version = "2.0"
1156
  python-versions = ">=3.12,<3.13"
1157
+ content-hash = "67b448796799eb25e83725fc27c23ff400273d48c767151c6462d8e1545052fe"
pyproject.toml CHANGED
@@ -12,6 +12,7 @@ python = ">=3.12,<3.13"
12
  streamlit = "^1.46.1"
13
  requests = "^2.32.4"
14
  watchdog = "^6.0.0"
 
15
 
16
  [tool.poetry.group.dev.dependencies]
17
  python-dotenv = "^1.1.1"
 
12
  streamlit = "^1.46.1"
13
  requests = "^2.32.4"
14
  watchdog = "^6.0.0"
15
+ sqlparse = "^0.5.3"
16
 
17
  [tool.poetry.group.dev.dependencies]
18
  python-dotenv = "^1.1.1"
src/pages/1_Chart_Generation.py CHANGED
@@ -1,6 +1,7 @@
1
  import streamlit as st
2
 
3
  from apis import generate_sql, generate_chart
 
4
 
5
 
6
  def main():
@@ -46,7 +47,7 @@ def main():
46
  st.write(message["content"])
47
  if "sql" in message:
48
  with st.expander("πŸ“ Generated SQL Query", expanded=False):
49
- st.code(message["sql"], language="sql")
50
  if "vega_spec" in message:
51
  try:
52
  with st.expander("πŸ“Š Chart Specification", expanded=False):
@@ -87,7 +88,7 @@ def main():
87
 
88
  # Display SQL query
89
  with st.expander("πŸ“ Generated SQL Query", expanded=False):
90
- st.code(sql_query, language="sql")
91
 
92
  # Generate chart
93
  with st.spinner("Generating chart..."):
 
1
  import streamlit as st
2
 
3
  from apis import generate_sql, generate_chart
4
+ from utils import format_sql
5
 
6
 
7
  def main():
 
47
  st.write(message["content"])
48
  if "sql" in message:
49
  with st.expander("πŸ“ Generated SQL Query", expanded=False):
50
+ st.code(format_sql(message["sql"]), language="sql")
51
  if "vega_spec" in message:
52
  try:
53
  with st.expander("πŸ“Š Chart Specification", expanded=False):
 
88
 
89
  # Display SQL query
90
  with st.expander("πŸ“ Generated SQL Query", expanded=False):
91
+ st.code(format_sql(sql_query), language="sql")
92
 
93
  # Generate chart
94
  with st.spinner("Generating chart..."):
src/pages/2_Query_And_Answer.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+
4
+ from apis import ask, run_sql
5
+ from utils import format_sql
6
+
7
+
8
+ def main():
9
+ st.title("πŸ’¬ Wren AI Cloud API Demo - Query and Answer")
10
+
11
+ if "api_key" not in st.session_state or "project_id" not in st.session_state:
12
+ st.error("Please enter your API Key and Project ID in the sidebar of Home page to get started.")
13
+ return
14
+ if not st.session_state.api_key or not st.session_state.project_id:
15
+ st.error("Please enter your API Key and Project ID in the sidebar of Home page to get started.")
16
+ return
17
+
18
+ api_key = st.session_state.api_key
19
+ project_id = st.session_state.project_id
20
+
21
+ st.markdown('Using APIs: [Ask](https://wrenai.readme.io/reference/post_ask-1), [Run SQL](https://wrenai.readme.io/reference/cloud_post_run-sql)')
22
+
23
+ # Sidebar for API configuration
24
+ with st.sidebar:
25
+ st.header("πŸ”§ Configuration")
26
+ # Sample size configuration
27
+ sample_size = st.slider(
28
+ "Sample Size",
29
+ min_value=100,
30
+ max_value=10000,
31
+ value=1000,
32
+ step=100,
33
+ help="Number of data points to include in results"
34
+ )
35
+
36
+ # Initialize chat history
37
+ if "qa_messages" not in st.session_state:
38
+ st.session_state.qa_messages = []
39
+ if "qa_thread_id" not in st.session_state:
40
+ st.session_state.qa_thread_id = ""
41
+
42
+ # Display chat history
43
+ for message in st.session_state.qa_messages:
44
+ with st.chat_message(message["role"]):
45
+ if message["role"] == "user":
46
+ st.write(message["content"])
47
+ else:
48
+ st.write(message["content"])
49
+ if "sql" in message:
50
+ with st.expander("πŸ” Generated SQL Query", expanded=False):
51
+ st.code(format_sql(message["sql"]), language="sql")
52
+
53
+ # Add button to run SQL
54
+ if st.button("πŸ”„ Run SQL Query", key=f"run_sql_{message.get('message_id', 'unknown')}"):
55
+ with st.spinner("Executing SQL query..."):
56
+ sql_result, error = run_sql(api_key, project_id, message["sql"], st.session_state.qa_thread_id, sample_size)
57
+
58
+ if sql_result:
59
+ data = sql_result.get("records", [])
60
+ if data:
61
+ # Convert to DataFrame for better display
62
+ df = pd.DataFrame(data)
63
+ st.success("SQL query executed successfully!")
64
+ st.dataframe(df, use_container_width=True)
65
+ else:
66
+ st.info("Query executed but returned no data.")
67
+ else:
68
+ st.error(f"Error executing SQL: {error}")
69
+
70
+ if "sql_results" in message:
71
+ st.subheader("πŸ” Query Results")
72
+ if message["sql_results"]:
73
+ st.dataframe(message["sql_results"], use_container_width=True)
74
+ else:
75
+ st.info("No results returned from the query.")
76
+
77
+ # Chat input
78
+ if prompt := st.chat_input("Ask a question about your data..."):
79
+ # Add user message to chat history
80
+ st.session_state.qa_messages.append({"role": "user", "content": prompt})
81
+
82
+ # Display user message
83
+ with st.chat_message("user"):
84
+ st.write(prompt)
85
+
86
+ # Generate response using ask API
87
+ with st.chat_message("assistant"):
88
+ with st.spinner("Generating answer..."):
89
+ ask_response, error = ask(api_key, project_id, prompt, st.session_state.qa_thread_id, sample_size=sample_size)
90
+
91
+ if ask_response:
92
+ answer = ask_response.get("summary", "")
93
+ sql_query = ask_response.get("sql", "")
94
+ st.session_state.qa_thread_id = ask_response.get("threadId", "")
95
+
96
+ if answer:
97
+ st.toast("Answer generated successfully!", icon="πŸŽ‰")
98
+
99
+ # Create unique message ID
100
+ message_id = len(st.session_state.qa_messages)
101
+
102
+ # Store the response
103
+ assistant_message = {
104
+ "role": "assistant",
105
+ "content": answer,
106
+ "message_id": message_id
107
+ }
108
+
109
+ if sql_query:
110
+ assistant_message["sql"] = sql_query
111
+
112
+ st.session_state.qa_messages.append(assistant_message)
113
+ st.write(answer)
114
+
115
+ # Display SQL query if available
116
+ if sql_query:
117
+ with st.expander("πŸ” Generated SQL Query", expanded=False):
118
+ st.code(format_sql(sql_query), language="sql")
119
+
120
+ # Add button to run SQL
121
+ if st.button("πŸ”„ Run SQL Query", key=f"run_sql_{message_id}"):
122
+ with st.spinner("Executing SQL query..."):
123
+ sql_result, error = run_sql(api_key, project_id, sql_query, st.session_state.qa_thread_id, sample_size)
124
+
125
+ if sql_result:
126
+ data = sql_result.get("records", [])
127
+ if data:
128
+ # Convert to DataFrame for better display
129
+ df = pd.DataFrame(data)
130
+ st.success("SQL query executed successfully!")
131
+ st.dataframe(df, use_container_width=True)
132
+ else:
133
+ st.info("Query executed but returned no data.")
134
+ else:
135
+ st.error(f"Error executing SQL: {error}")
136
+ else:
137
+ st.toast("No answer was generated. Please try rephrasing your question.", icon="πŸ€”")
138
+ assistant_message = {
139
+ "role": "assistant",
140
+ "content": "I couldn't generate an answer for your question. Please try rephrasing it or make sure it's related to your data."
141
+ }
142
+ st.session_state.qa_messages.append(assistant_message)
143
+ else:
144
+ st.toast(f"Error generating answer: {error}", icon="πŸ€”")
145
+ assistant_message = {
146
+ "role": "assistant",
147
+ "content": "Sorry, I couldn't process your request. Please check your API credentials and try again."
148
+ }
149
+ st.session_state.qa_messages.append(assistant_message)
150
+
151
+ # Clear chat button
152
+ if st.sidebar.button("🧹 Clear Chat History"):
153
+ st.session_state.qa_messages = []
154
+ st.session_state.qa_thread_id = ""
155
+ st.rerun()
156
+
157
+
158
+ if __name__ == "__main__":
159
+ main()
src/utils.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import sqlparse
2
+
3
+
4
+ def format_sql(sql: str) -> str:
5
+ return sqlparse.format(
6
+ sql,
7
+ reindent=True,
8
+ keyword_case="upper",
9
+ )