cyyeh commited on
Commit
e9af7e9
Β·
1 Parent(s): 8896669

update code

Browse files
Files changed (2) hide show
  1. requirements.txt +2 -3
  2. src/streamlit_app.py +218 -38
requirements.txt CHANGED
@@ -1,3 +1,2 @@
1
- altair
2
- pandas
3
- streamlit
 
1
+ streamlit
2
+ requests
 
src/streamlit_app.py CHANGED
@@ -1,40 +1,220 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import requests
3
 
4
+ # Page configuration
5
+ st.set_page_config(
6
+ page_title="Wren AI Cloud API Demo",
7
+ page_icon="πŸ“Š",
8
+ layout="wide"
9
+ )
10
+
11
+ def generate_sql(
12
+ api_key: str,
13
+ project_id: str,
14
+ query: str,
15
+ thread_id: str = "",
16
+ ) -> dict:
17
+ """Generate SQL from natural language query."""
18
+ base_url = "https://cloud.getwren.ai/api/v1"
19
+ endpoint = f"{base_url}/generate_sql"
20
+ headers = {
21
+ "Authorization": f"Bearer {api_key}",
22
+ "Content-Type": "application/json"
23
+ }
24
+ payload = {
25
+ "projectId": project_id,
26
+ "question": query,
27
+ }
28
+ if thread_id:
29
+ payload["threadId"] = thread_id
30
+
31
+ try:
32
+ response = requests.post(endpoint, json=payload, headers=headers)
33
+ response.raise_for_status()
34
+ return response.json()
35
+ except requests.exceptions.RequestException as e:
36
+ st.toast(f"Error generating SQL: {e}", icon="🚨")
37
+ return {}
38
+
39
+ def generate_chart(
40
+ api_key: str,
41
+ project_id: str,
42
+ question: str,
43
+ sql: str,
44
+ sample_size: int = 1000,
45
+ thread_id: str = ""
46
+ ) -> dict:
47
+ """Generate a chart from query results."""
48
+ base_url = "https://cloud.getwren.ai/api/v1"
49
+ endpoint = f"{base_url}/generate_vega_chart"
50
+ headers = {
51
+ "Authorization": f"Bearer {api_key}",
52
+ "Content-Type": "application/json"
53
+ }
54
+ payload = {
55
+ "projectId": project_id,
56
+ "question": question,
57
+ "sql": sql,
58
+ "sampleSize": sample_size
59
+ }
60
+ if thread_id:
61
+ payload["threadId"] = thread_id
62
+
63
+ try:
64
+ response = requests.post(endpoint, json=payload, headers=headers)
65
+ response.raise_for_status()
66
+ return response.json()
67
+ except requests.exceptions.RequestException as e:
68
+ st.toast(f"Error generating chart: {e}", icon="🚨")
69
+ return {}
70
+
71
+ def main():
72
+ st.title("πŸ“Š Wren AI Cloud API Demo")
73
+ st.markdown("Ask questions about your data and get both SQL queries and beautiful charts!")
74
+
75
+ # Sidebar for API configuration
76
+ with st.sidebar:
77
+ st.header("πŸ”§ Configuration")
78
+ api_key = st.text_input(
79
+ "API Key",
80
+ type="password",
81
+ placeholder="sk-your-api-key-here",
82
+ help="Enter your Wren AI Cloud API key"
83
+ )
84
+ project_id = st.text_input(
85
+ "Project ID",
86
+ placeholder="1234",
87
+ help="Enter your Wren AI Cloud project ID"
88
+ )
89
+
90
+ # Sample size configuration
91
+ sample_size = st.slider(
92
+ "Chart Sample Size",
93
+ min_value=100,
94
+ max_value=10000,
95
+ value=1000,
96
+ step=100,
97
+ help="Number of data points to include in charts"
98
+ )
99
+
100
+ # Main chat interface
101
+ if not api_key or not project_id:
102
+ st.warning("⚠️ Please enter your API Key and Project ID in the sidebar to get started.")
103
+ st.info("""
104
+ **How to get started:**
105
+ 1. Enter your Wren AI API Key in the sidebar
106
+ 2. Enter your Project ID
107
+ 3. Ask questions about your data in natural language
108
+ 4. Get SQL queries and interactive charts automatically!
109
+ """)
110
+ return
111
+
112
+ # Initialize chat history
113
+ if "messages" not in st.session_state:
114
+ st.session_state.messages = []
115
+ if "thread_id" not in st.session_state:
116
+ st.session_state.thread_id = ""
117
+
118
+ # Display chat history
119
+ for message in st.session_state.messages:
120
+ with st.chat_message(message["role"]):
121
+ if message["role"] == "user":
122
+ st.write(message["content"])
123
+ else:
124
+ st.write(message["content"])
125
+ if "sql" in message:
126
+ with st.expander("πŸ“ Generated SQL Query", expanded=False):
127
+ st.code(message["sql"], language="sql")
128
+ if "vega_spec" in message:
129
+ try:
130
+ with st.expander("πŸ“Š Chart Specification", expanded=False):
131
+ st.json(message["vega_spec"])
132
+ st.vega_lite_chart(message["vega_spec"])
133
+ except Exception as e:
134
+ st.toast(f"Error rendering chart: {e}", icon="🚨")
135
+
136
+ # Chat input
137
+ if prompt := st.chat_input("Ask a question about your data..."):
138
+ # Add user message to chat history
139
+ st.session_state.messages.append({"role": "user", "content": prompt})
140
+
141
+ # Display user message
142
+ with st.chat_message("user"):
143
+ st.write(prompt)
144
+
145
+ # Generate response
146
+ with st.chat_message("assistant"):
147
+ with st.spinner("Generating SQL query..."):
148
+ sql_response = generate_sql(api_key, project_id, prompt, st.session_state.thread_id)
149
+
150
+ if sql_response:
151
+ sql_query = sql_response.get("sql", "")
152
+ st.session_state.thread_id = sql_response.get("threadId", "")
153
+
154
+ if sql_query:
155
+ st.toast("SQL query generated successfully!", icon="πŸŽ‰")
156
+
157
+ # Store the response
158
+ assistant_message = {
159
+ "role": "assistant",
160
+ "content": f"I've generated a SQL query for your question: '{prompt}'",
161
+ "sql": sql_query
162
+ }
163
+ st.session_state.messages.append(assistant_message)
164
+ st.write(assistant_message["content"])
165
+
166
+ # Display SQL query
167
+ with st.expander("πŸ“ Generated SQL Query", expanded=False):
168
+ st.code(sql_query, language="sql")
169
+
170
+ # Generate chart
171
+ with st.spinner("Generating chart..."):
172
+ chart_response = generate_chart(api_key, project_id, prompt, sql_query, sample_size, st.session_state.thread_id)
173
+
174
+ if chart_response:
175
+ vega_spec = chart_response.get("vegaSpec", {})
176
+ if vega_spec:
177
+ st.toast("Chart generated successfully!", icon="πŸŽ‰")
178
+
179
+ assistant_message = {
180
+ "role": "assistant",
181
+ "content": f"I've generated a Chart for your question: '{prompt}'",
182
+ "vega_spec": vega_spec
183
+ }
184
+ st.session_state.messages.append(assistant_message)
185
+ st.write(assistant_message["content"])
186
+
187
+ # Display chart
188
+ try:
189
+ # Show chart specification in expander
190
+ with st.expander("πŸ“Š Chart Specification", expanded=False):
191
+ st.json(vega_spec)
192
+ st.vega_lite_chart(vega_spec)
193
+ except Exception as e:
194
+ st.toast(f"Error rendering chart: {e}", icon="🚨")
195
+ else:
196
+ st.toast("Failed to generate chart. Please check your query and try again.", icon="🚨")
197
+ else:
198
+ st.toast("Failed to generate chart. Please check your query and try again.", icon="🚨")
199
+ else:
200
+ st.toast("No SQL query was generated. Please try rephrasing your question.", icon="🚨")
201
+ assistant_message = {
202
+ "role": "assistant",
203
+ "content": "I couldn't generate a SQL query for your question. Please try rephrasing it or make sure it's related to your data."
204
+ }
205
+ st.session_state.messages.append(assistant_message)
206
+ else:
207
+ assistant_message = {
208
+ "role": "assistant",
209
+ "content": "Sorry, I couldn't process your request. Please check your API credentials and try again."
210
+ }
211
+ st.session_state.messages.append(assistant_message)
212
+
213
+ # Clear chat button
214
+ if st.sidebar.button("πŸ—‘οΈ Clear Chat History"):
215
+ st.session_state.messages = []
216
+ st.session_state.thread_id = ""
217
+ st.rerun()
218
+
219
+ if __name__ == "__main__":
220
+ main()