cyyeh commited on
Commit
b038dbc
Β·
1 Parent(s): 383166b
Files changed (6) hide show
  1. .gitignore +2 -1
  2. Dockerfile +1 -1
  3. Makefile +1 -1
  4. README.md +1 -1
  5. src/apis.py +61 -0
  6. src/{streamlit_app.py β†’ app.py} +9 -64
.gitignore CHANGED
@@ -1 +1,2 @@
1
- .DS_Store
 
 
1
+ .DS_Store
2
+ __pycache__
Dockerfile CHANGED
@@ -31,4 +31,4 @@ EXPOSE 8501
31
 
32
  HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
33
 
34
- ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
 
31
 
32
  HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
33
 
34
+ ENTRYPOINT ["streamlit", "run", "src/app.py", "--server.port=8501", "--server.address=0.0.0.0"]
Makefile CHANGED
@@ -1,5 +1,5 @@
1
  run:
2
- poetry run streamlit run src/streamlit_app.py
3
 
4
  deps:
5
  poetry export --without-hashes --format=requirements.txt > requirements.txt
 
1
  run:
2
+ poetry run streamlit run src/app.py
3
 
4
  deps:
5
  poetry export --without-hashes --format=requirements.txt > requirements.txt
README.md CHANGED
@@ -14,7 +14,7 @@ license: mit
14
 
15
  # Welcome to Streamlit!
16
 
17
- Edit `/src/streamlit_app.py` to customize this app to your heart's desire. :heart:
18
 
19
  If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
20
  forums](https://discuss.streamlit.io).
 
14
 
15
  # Welcome to Streamlit!
16
 
17
+ Edit `/src/app.py` to customize this app to your heart's desire. :heart:
18
 
19
  If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
20
  forums](https://discuss.streamlit.io).
src/apis.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+
3
+
4
+ def generate_sql(
5
+ api_key: str,
6
+ project_id: str,
7
+ query: str,
8
+ thread_id: str = "",
9
+ ) -> tuple[dict, str]:
10
+ """Generate SQL from natural language query."""
11
+ base_url = "https://cloud.getwren.ai/api/v1"
12
+ endpoint = f"{base_url}/generate_sql"
13
+ headers = {
14
+ "Authorization": f"Bearer {api_key}",
15
+ "Content-Type": "application/json"
16
+ }
17
+ payload = {
18
+ "projectId": project_id,
19
+ "question": query,
20
+ }
21
+ if thread_id:
22
+ payload["threadId"] = thread_id
23
+
24
+ try:
25
+ response = requests.post(endpoint, json=payload, headers=headers)
26
+ response.raise_for_status()
27
+ return response.json(), ""
28
+ except requests.exceptions.RequestException as e:
29
+ return {}, e
30
+
31
+
32
+ def generate_chart(
33
+ api_key: str,
34
+ project_id: str,
35
+ question: str,
36
+ sql: str,
37
+ sample_size: int = 1000,
38
+ thread_id: str = ""
39
+ ) -> tuple[dict, str]:
40
+ """Generate a chart from query results."""
41
+ base_url = "https://cloud.getwren.ai/api/v1"
42
+ endpoint = f"{base_url}/generate_vega_chart"
43
+ headers = {
44
+ "Authorization": f"Bearer {api_key}",
45
+ "Content-Type": "application/json"
46
+ }
47
+ payload = {
48
+ "projectId": project_id,
49
+ "question": question,
50
+ "sql": sql,
51
+ "sampleSize": sample_size
52
+ }
53
+ if thread_id:
54
+ payload["threadId"] = thread_id
55
+
56
+ try:
57
+ response = requests.post(endpoint, json=payload, headers=headers)
58
+ response.raise_for_status()
59
+ return response.json(), ""
60
+ except requests.exceptions.RequestException as e:
61
+ return {}, e
src/{streamlit_app.py β†’ app.py} RENAMED
@@ -1,5 +1,7 @@
1
  import streamlit as st
2
- import requests
 
 
3
 
4
  # Page configuration
5
  st.set_page_config(
@@ -8,65 +10,6 @@ st.set_page_config(
8
  layout="wide"
9
  )
10
 
11
- def generate_sql(
12
- api_key: str,
13
- project_id: str,
14
- query: str,
15
- thread_id: str = "",
16
- ) -> dict:
17
- """Generate SQL from natural language query."""
18
- base_url = "https://cloud.getwren.ai/api/v1"
19
- endpoint = f"{base_url}/generate_sql"
20
- headers = {
21
- "Authorization": f"Bearer {api_key}",
22
- "Content-Type": "application/json"
23
- }
24
- payload = {
25
- "projectId": project_id,
26
- "question": query,
27
- }
28
- if thread_id:
29
- payload["threadId"] = thread_id
30
-
31
- try:
32
- response = requests.post(endpoint, json=payload, headers=headers)
33
- response.raise_for_status()
34
- return response.json()
35
- except requests.exceptions.RequestException as e:
36
- st.toast(f"Error generating SQL: {e}", icon="🚨")
37
- return {}
38
-
39
- def generate_chart(
40
- api_key: str,
41
- project_id: str,
42
- question: str,
43
- sql: str,
44
- sample_size: int = 1000,
45
- thread_id: str = ""
46
- ) -> dict:
47
- """Generate a chart from query results."""
48
- base_url = "https://cloud.getwren.ai/api/v1"
49
- endpoint = f"{base_url}/generate_vega_chart"
50
- headers = {
51
- "Authorization": f"Bearer {api_key}",
52
- "Content-Type": "application/json"
53
- }
54
- payload = {
55
- "projectId": project_id,
56
- "question": question,
57
- "sql": sql,
58
- "sampleSize": sample_size
59
- }
60
- if thread_id:
61
- payload["threadId"] = thread_id
62
-
63
- try:
64
- response = requests.post(endpoint, json=payload, headers=headers)
65
- response.raise_for_status()
66
- return response.json()
67
- except requests.exceptions.RequestException as e:
68
- st.toast(f"Error generating chart: {e}", icon="🚨")
69
- return {}
70
 
71
  def main():
72
  st.title("πŸ“Š Wren AI Cloud API Demo")
@@ -102,7 +45,7 @@ def main():
102
  st.warning("⚠️ Please enter your API Key and Project ID in the sidebar to get started.")
103
  st.info("""
104
  **How to get started:**
105
- 1. Enter your Wren AI API Key in the sidebar
106
  2. Enter your Project ID
107
  3. Ask questions about your data in natural language
108
  4. Get SQL queries and interactive charts automatically!
@@ -145,7 +88,7 @@ def main():
145
  # Generate response
146
  with st.chat_message("assistant"):
147
  with st.spinner("Generating SQL query..."):
148
- sql_response = generate_sql(api_key, project_id, prompt, st.session_state.thread_id)
149
 
150
  if sql_response:
151
  sql_query = sql_response.get("sql", "")
@@ -169,7 +112,7 @@ def main():
169
 
170
  # Generate chart
171
  with st.spinner("Generating chart..."):
172
- chart_response = generate_chart(api_key, project_id, prompt, sql_query, sample_size, st.session_state.thread_id)
173
 
174
  if chart_response:
175
  vega_spec = chart_response.get("vegaSpec", {})
@@ -195,7 +138,7 @@ def main():
195
  else:
196
  st.toast("Failed to generate chart. Please check your query and try again.", icon="🚨")
197
  else:
198
- st.toast("Failed to generate chart. Please check your query and try again.", icon="🚨")
199
  else:
200
  st.toast("No SQL query was generated. Please try rephrasing your question.", icon="🚨")
201
  assistant_message = {
@@ -204,6 +147,7 @@ def main():
204
  }
205
  st.session_state.messages.append(assistant_message)
206
  else:
 
207
  assistant_message = {
208
  "role": "assistant",
209
  "content": "Sorry, I couldn't process your request. Please check your API credentials and try again."
@@ -216,5 +160,6 @@ def main():
216
  st.session_state.thread_id = ""
217
  st.rerun()
218
 
 
219
  if __name__ == "__main__":
220
  main()
 
1
  import streamlit as st
2
+
3
+ from apis import generate_sql, generate_chart
4
+
5
 
6
  # Page configuration
7
  st.set_page_config(
 
10
  layout="wide"
11
  )
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  def main():
15
  st.title("πŸ“Š Wren AI Cloud API Demo")
 
45
  st.warning("⚠️ Please enter your API Key and Project ID in the sidebar to get started.")
46
  st.info("""
47
  **How to get started:**
48
+ 1. Enter your Wren AI Cloud API Key in the sidebar
49
  2. Enter your Project ID
50
  3. Ask questions about your data in natural language
51
  4. Get SQL queries and interactive charts automatically!
 
88
  # Generate response
89
  with st.chat_message("assistant"):
90
  with st.spinner("Generating SQL query..."):
91
+ sql_response, error = generate_sql(api_key, project_id, prompt, st.session_state.thread_id)
92
 
93
  if sql_response:
94
  sql_query = sql_response.get("sql", "")
 
112
 
113
  # Generate chart
114
  with st.spinner("Generating chart..."):
115
+ chart_response, error = generate_chart(api_key, project_id, prompt, sql_query, sample_size, st.session_state.thread_id)
116
 
117
  if chart_response:
118
  vega_spec = chart_response.get("vegaSpec", {})
 
138
  else:
139
  st.toast("Failed to generate chart. Please check your query and try again.", icon="🚨")
140
  else:
141
+ st.toast(f"Failed to generate chart. Please check your query and try again.: {error}", icon="🚨")
142
  else:
143
  st.toast("No SQL query was generated. Please try rephrasing your question.", icon="🚨")
144
  assistant_message = {
 
147
  }
148
  st.session_state.messages.append(assistant_message)
149
  else:
150
+ st.toast(f"Error generating SQL: {error}", icon="🚨")
151
  assistant_message = {
152
  "role": "assistant",
153
  "content": "Sorry, I couldn't process your request. Please check your API credentials and try again."
 
160
  st.session_state.thread_id = ""
161
  st.rerun()
162
 
163
+
164
  if __name__ == "__main__":
165
  main()