# TextToSQL / app.py
# Last commit: 4aa996b "Update app.py" by thechaiexperiment (verified), 4.34 kB
import os
import re
import sqlite3
from pathlib import Path

import gradio as gr
import matplotlib.pyplot as plt
import openai
import pandas as pd
import seaborn as sns
# OpenRouter credentials and model. The API key is read from the environment
# (set OPENROUTER_API_KEY, e.g. as a deployment secret) so that no secret is
# committed to source control.
OPENROUTER_API_KEY = os.environ.get("OPENROUTER_API_KEY", "")
OPENROUTER_MODEL = "sophosympatheia/rogue-rose-103b-v0.2:free"
if not OPENROUTER_API_KEY:
    print("OPENROUTER_API_KEY is not set; SQL generation requests will fail.")
# Database Path
db_path = "ecommerce.db"
# Warn early if the database file is missing (queries will report an error).
if not os.path.exists(db_path):
    print("Database file not found! Please upload ecommerce.db.")
# Initialize OpenAI client, pointed at OpenRouter's OpenAI-compatible endpoint.
openai_client = openai.OpenAI(api_key=OPENROUTER_API_KEY, base_url="https://openrouter.ai/api/v1")
# Few-shot demonstrations placed in the prompt ahead of the user's question.
# Each pair is (natural-language question, SQLite query answering it); the
# queries use SQLite-specific functions such as strftime.
_FEW_SHOT_PAIRS = (
    (
        "Find the busiest months for orders.",
        "SELECT strftime('%m', order_purchase_timestamp) AS month, COUNT(*) AS order_count FROM orders GROUP BY month ORDER BY order_count DESC;",
    ),
    (
        "Show all customers from São Paulo.",
        "SELECT * FROM customers WHERE customer_state = 'SP';",
    ),
    (
        "Find the total sales per product.",
        "SELECT product_id, SUM(price) FROM order_items GROUP BY product_id;",
    ),
    (
        "List all orders placed in 2017.",
        "SELECT * FROM orders WHERE order_purchase_timestamp LIKE '2017%';",
    ),
)
few_shot_examples = [{"input": question, "output": sql} for question, sql in _FEW_SHOT_PAIRS]
# Matches a Markdown code fence (optionally tagged sql/sqlite) and captures its body.
_SQL_FENCE_RE = re.compile(r"```(?:sqlite|sql)?\s*(.*?)```", re.IGNORECASE | re.DOTALL)


def _extract_sql(reply):
    """Return the bare SQL statement contained in a chat-model reply.

    Chat models frequently ignore the few-shot format: they wrap the query in a
    Markdown code fence, prefix it with "Output:", or keep generating another
    "Input:" turn. Without this cleanup a correct SELECT would fail the
    ``startswith("select")`` check in ``text_to_sql``.

    Args:
        reply: Raw message content from the model; may be ``None``.

    Returns:
        The extracted SQL text with surrounding whitespace removed ("" if none).
    """
    text = (reply or "").strip()
    fenced = _SQL_FENCE_RE.search(text)
    if fenced:
        text = fenced.group(1).strip()
    if text.lower().startswith("output:"):
        text = text[len("output:"):].strip()
    # The prompt is a repeated "Input:/Output:" pattern; cut off any extra
    # "Input:" turn the model invents after its answer.
    return text.split("\nInput:", 1)[0].strip()


# Function: Convert Text to SQL
def text_to_sql(query):
    """Translate a natural-language question into a SQLite SELECT via the LLM.

    Builds a few-shot prompt from ``few_shot_examples``, sends it to
    ``OPENROUTER_MODEL`` through ``openai_client`` and cleans up the reply.

    Args:
        query: The user's question in plain English.

    Returns:
        The generated SQL when it starts with SELECT (case-insensitive);
        otherwise a string starting with "Error:" describing the failure.
        Errors are returned rather than raised so the UI can display them.
    """
    examples = "".join(
        f"Input: {example['input']}\nOutput: {example['output']}\n\n"
        for example in few_shot_examples
    )
    prompt = (
        "Convert the following queries into SQLite-compatible SQL:\n\n"
        f"{examples}Input: {query}\nOutput:"
    )
    try:
        response = openai_client.chat.completions.create(
            model=OPENROUTER_MODEL,
            messages=[
                {"role": "system", "content": "You are an SQLite expert."},
                {"role": "user", "content": prompt},
            ],
        )
        reply = response.choices[0].message.content
    except Exception as e:  # API/network/auth failures or an empty choices list
        return f"Error: {e}"
    sql_query = _extract_sql(reply)
    if not sql_query.lower().startswith("select"):
        # Only SELECT statements are accepted; show the raw reply for debugging.
        return f"Error: Invalid SQL generated - {(reply or '').strip()}"
    return sql_query
# Function: Execute SQL on SQLite Database
def execute_sql(sql_query, database=None):
    """Run a query against the SQLite database and return the result set.

    The database is opened read-only (SQLite URI ``mode=ro``). The SQL comes
    from an LLM, so this stops a generated statement from modifying data. It
    also stops ``sqlite3.connect`` from silently creating an empty database
    file when the expected one is missing.

    Args:
        sql_query: SQL text to execute.
        database: Path to the SQLite file; defaults to the module-level ``db_path``.

    Returns:
        A ``pandas.DataFrame`` with the result rows, or a string starting with
        "SQL Execution Error:" if opening the database or running the query fails.
    """
    path = db_path if database is None else database
    try:
        # Path.as_uri() yields an absolute, percent-encoded file: URI.
        uri = Path(path).resolve().as_uri() + "?mode=ro"
        conn = sqlite3.connect(uri, uri=True)
        try:
            return pd.read_sql_query(sql_query, conn)
        finally:
            # Release the connection even when the query fails.
            conn.close()
    except Exception as e:
        return f"SQL Execution Error: {e}"
# Function: Generate Data Visualization
def visualize_data(df, output_path="chart.png"):
    """Render a quick chart for a query result and save it as a PNG.

    Chart choice depends on the numeric columns found in ``df``:
      * exactly one numeric column -> histogram (with KDE) of that column;
      * exactly two numeric columns -> scatter plot of the first vs the second;
      * three or more numeric columns and fewer than 10 rows -> pie chart of
        the first numeric column, labelled by the first column of ``df``;
      * otherwise -> bar chart of the first numeric column against the first column.

    Args:
        df: Query result to plot.
        output_path: Where to write the image (default "chart.png", relative to
            the current working directory).

    Returns:
        ``output_path`` on success. Returns ``None`` when there is nothing sensible
        to plot (empty frame, fewer than two columns, or no numeric columns). Also
        returns ``None`` when the plotting library rejects the data, e.g. negative
        values in a pie chart.
    """
    if df.empty or df.shape[1] < 2:
        return None
    numeric_cols = df.select_dtypes(include=["number"]).columns
    if len(numeric_cols) < 1:
        return None
    # Apply the theme before creating the axes so they pick up the style.
    sns.set_theme(style="darkgrid")
    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        first = numeric_cols[0]
        if len(numeric_cols) == 1:
            sns.histplot(df[first], bins=10, kde=True, color="teal", ax=ax)
            ax.set_title(f"Distribution of {first}")
        elif len(numeric_cols) == 2:
            second = numeric_cols[1]
            sns.scatterplot(x=df[first], y=df[second], color="blue", ax=ax)
            ax.set_title(f"{first} vs {second}")
        elif df.shape[0] < 10:
            ax.pie(
                df[first],
                labels=df.iloc[:, 0],
                autopct="%1.1f%%",
                colors=sns.color_palette("pastel"),
            )
            ax.set_title(f"Proportion of {first}")
        else:
            sns.barplot(x=df.iloc[:, 0], y=df[first], palette="coolwarm", ax=ax)
            ax.tick_params(axis="x", labelrotation=45)
            ax.set_title(f"{df.columns[0]} vs {first}")
        fig.tight_layout()
        fig.savefig(output_path)
    except (ValueError, TypeError) as e:
        # The chart is best-effort: the SQL and results are still shown without it.
        print(f"Visualization skipped: {e}")
        return None
    finally:
        # pyplot keeps every figure alive until closed; without this each
        # request would leak one figure in the long-running server.
        plt.close(fig)
    return output_path
# Gradio UI
def gradio_ui(query):
    """Handle one UI request: natural language -> SQL -> results -> chart.

    Args:
        query: The user's question from the input textbox.

    Returns:
        A 3-tuple for the (SQL, results, chart) output components:
          * the generated SQL, or an "Error: ..." message;
          * the results as plain text, or an error message;
          * the chart image path, or ``None``.
    """
    sql_query = text_to_sql(query)
    if sql_query.startswith("Error:"):
        # text_to_sql reports failures as text; executing that text as SQL
        # would only bury the real problem under a SQLite syntax error.
        return sql_query, "Query not executed because no valid SQL was generated.", None
    results = execute_sql(sql_query)
    if not isinstance(results, pd.DataFrame):
        # execute_sql returned an error message string.
        return sql_query, results, None
    return sql_query, results.to_string(index=False), visualize_data(results)
# Page layout: a heading, the question box and trigger button, then three
# output panes (generated SQL, textual results, chart) filled by gradio_ui.
with gr.Blocks() as demo:
    gr.Markdown("## SQL Explorer: Text to SQL with a Simple Visualization")
    query_input = gr.Textbox(
        label="Enter your query",
        placeholder="Enter your query in English.",
    )
    submit_btn = gr.Button("Convert & Execute")
    sql_output = gr.Textbox(label="Generated SQL Query")
    table_output = gr.Textbox(label="Query Results")
    chart_output = gr.Image(label="Data Visualization")
    submit_btn.click(
        fn=gradio_ui,
        inputs=[query_input],
        outputs=[sql_output, table_output, chart_output],
    )
# Launch the server only when this file is run as a script (e.g. `python app.py`),
# so importing the module (tests, tooling that looks up `demo`) has no side effects.
if __name__ == "__main__":
    demo.launch()