Spaces:
Running
Running
File size: 4,343 Bytes
b480321 3ddc773 b480321 3ddc773 b480321 4aa996b b480321 350e55d 4aa996b 3ddc773 4aa996b b480321 4aa996b b480321 4aa996b b480321 4aa996b b480321 4aa996b b480321 350e55d 4aa996b b480321 4aa996b 3ddc773 4aa996b 3ddc773 4aa996b 3ddc773 4aa996b 3ddc773 4aa996b 3ddc773 4aa996b 3ddc773 4aa996b 3ddc773 b480321 3ddc773 350e55d 3ddc773 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
import gradio as gr
import openai
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
# OpenRouter credentials and model selection.
# SECURITY: the API key is read from the OPENROUTER_API_KEY environment
# variable instead of being committed to source control. Any key that was
# previously hard-coded in this file must be treated as leaked and revoked.
OPENROUTER_API_KEY = os.environ.get("OPENROUTER_API_KEY", "")
OPENROUTER_MODEL = "sophosympatheia/rogue-rose-103b-v0.2:free"

# Warn at startup rather than failing: the UI can still load, and the
# request error is reported to the user by text_to_sql.
if not OPENROUTER_API_KEY:
    print("OPENROUTER_API_KEY is not set! Requests to OpenRouter will fail.")

# Path of the SQLite database file that generated queries run against.
db_path = "ecommerce.db"

# Warn early if the dataset is missing; queries will fail until it is uploaded.
if not os.path.exists(db_path):
    print("Database file not found! Please upload ecommerce.db.")

# The OpenAI client is pointed at the OpenRouter endpoint via base_url.
openai_client = openai.OpenAI(api_key=OPENROUTER_API_KEY, base_url="https://openrouter.ai/api/v1")
# Few-shot examples shown to the model: each pair is an English question and
# the SQLite-compatible query that answers it.
_FEW_SHOT_PAIRS = (
    (
        "Find the busiest months for orders.",
        "SELECT strftime('%m', order_purchase_timestamp) AS month, COUNT(*) AS order_count FROM orders GROUP BY month ORDER BY order_count DESC;",
    ),
    (
        "Show all customers from São Paulo.",
        "SELECT * FROM customers WHERE customer_state = 'SP';",
    ),
    (
        "Find the total sales per product.",
        "SELECT product_id, SUM(price) FROM order_items GROUP BY product_id;",
    ),
    (
        "List all orders placed in 2017.",
        "SELECT * FROM orders WHERE order_purchase_timestamp LIKE '2017%';",
    ),
)

# Expose the pairs as the list of {"input", "output"} dicts the prompt builder reads.
few_shot_examples = [
    {"input": question, "output": sql}
    for question, sql in _FEW_SHOT_PAIRS
]
# Function: Convert Text to SQL
def _strip_code_fence(text):
    """Return *text* stripped of whitespace and of a surrounding Markdown code fence."""
    cleaned = text.strip()
    if not cleaned.startswith("```"):
        return cleaned
    cleaned = cleaned[3:]
    if cleaned.endswith("```"):
        cleaned = cleaned[:-3]
    # The opening fence may be followed by a language tag on its own line
    # (e.g. "```sql"); drop that line, but never a line that holds real SQL.
    first_line, newline, remainder = cleaned.partition("\n")
    if newline and first_line.strip().lower() in ("", "sql", "sqlite"):
        cleaned = remainder
    return cleaned.strip()


def text_to_sql(query):
    """Translate an English question into a SQLite SELECT statement.

    Builds a few-shot prompt from ``few_shot_examples``, sends it to the
    configured model through ``openai_client`` and returns the model's reply.

    Args:
        query: The question in natural language.

    Returns:
        The generated SQL when it starts with ``SELECT`` (case-insensitive).
        Otherwise a string starting with ``"Error:"``: either the text of the
        exception raised by the API call, or an "Invalid SQL generated"
        message containing the model's reply.
    """
    prompt_parts = ["Convert the following queries into SQLite-compatible SQL:\n\n"]
    prompt_parts.extend(
        f"Input: {example['input']}\nOutput: {example['output']}\n\n"
        for example in few_shot_examples
    )
    prompt_parts.append(f"Input: {query}\nOutput:")
    prompt = "".join(prompt_parts)

    try:
        response = openai_client.chat.completions.create(
            model=OPENROUTER_MODEL,
            messages=[
                {"role": "system", "content": "You are an SQLite expert."},
                {"role": "user", "content": prompt},
            ],
        )
        # Chat models frequently wrap code in ```sql ... ``` fences; without
        # unwrapping, a valid query would fail the SELECT check below.
        sql_query = _strip_code_fence(response.choices[0].message.content)
    except Exception as e:
        # Errors are reported as strings so the UI can display them.
        return f"Error: {e}"

    # Only read-only SELECT statements are passed on for execution.
    if sql_query.lower().startswith("select"):
        return sql_query
    return f"Error: Invalid SQL generated - {sql_query}"
# Function: Execute SQL on SQLite Database
def execute_sql(sql_query):
    """Run *sql_query* against the SQLite database at ``db_path``.

    Args:
        sql_query: SQL text to execute.

    Returns:
        A ``pandas.DataFrame`` with the result rows on success, or a string
        starting with ``"SQL Execution Error:"`` describing the failure.
    """
    # Local import keeps this block self-contained.
    from contextlib import closing

    try:
        # closing() guarantees the connection is released even when the query
        # raises; sqlite3's own context manager only manages transactions and
        # does not close the connection.
        with closing(sqlite3.connect(db_path)) as conn:
            return pd.read_sql_query(sql_query, conn)
    except Exception as e:
        # Errors are reported as strings so the UI can display them.
        return f"SQL Execution Error: {e}"
# Function: Generate Data Visualization
def visualize_data(df):
    """Render a chart of *df* to ``chart.png`` and return that path.

    The chart type depends on how many numeric columns the frame has:
    one gives a histogram of that column, two give a scatter plot of the first
    against the second; with more, a pie chart is drawn when there are fewer
    than 10 rows and a bar chart otherwise (both use the first column as
    labels and the first numeric column as values).

    Args:
        df: Query result as a ``pandas.DataFrame``.

    Returns:
        ``"chart.png"`` when a chart was written, or ``None`` when the frame is
        empty, has fewer than two columns, or has no numeric column.
    """
    if df.empty or df.shape[1] < 2:
        return None
    numeric_cols = df.select_dtypes(include=['number']).columns
    if len(numeric_cols) < 1:
        return None

    fig = plt.figure(figsize=(6, 4))
    try:
        sns.set_theme(style="darkgrid")
        if len(numeric_cols) == 1:
            sns.histplot(df[numeric_cols[0]], bins=10, kde=True, color="teal")
            plt.title(f"Distribution of {numeric_cols[0]}")
        elif len(numeric_cols) == 2:
            sns.scatterplot(x=df[numeric_cols[0]], y=df[numeric_cols[1]], color="blue")
            plt.title(f"{numeric_cols[0]} vs {numeric_cols[1]}")
        elif df.shape[0] < 10:
            plt.pie(df[numeric_cols[0]], labels=df.iloc[:, 0], autopct='%1.1f%%', colors=sns.color_palette("pastel"))
            plt.title(f"Proportion of {numeric_cols[0]}")
        else:
            sns.barplot(x=df.iloc[:, 0], y=df[numeric_cols[0]], palette="coolwarm")
            plt.xticks(rotation=45)
            plt.title(f"{df.columns[0]} vs {numeric_cols[0]}")
        plt.tight_layout()
        plt.savefig("chart.png")
    finally:
        # pyplot keeps every figure alive until it is closed; without this a
        # long-running server accumulates one figure per request.
        plt.close(fig)
    return "chart.png"
# Gradio UI
def gradio_ui(query):
    """Handle one request: generate SQL, run it, and chart the result.

    Args:
        query: The question typed by the user.

    Returns:
        A 3-tuple ``(sql_text, results_text, chart_path)``: the generated SQL
        (or the generation error message), the result table rendered as text
        (or an error/status message), and the chart image path or ``None``.
    """
    sql_query = text_to_sql(query)

    # text_to_sql reports failures as strings prefixed with "Error:". Such a
    # string is not SQL; sending it to the database would only produce a
    # misleading syntax error on top of the real one.
    if sql_query.startswith("Error:"):
        return sql_query, "Query was not executed because no valid SQL was generated.", None

    results = execute_sql(sql_query)
    # execute_sql returns an error string instead of a DataFrame on failure.
    if not isinstance(results, pd.DataFrame):
        return sql_query, results, None

    return sql_query, results.to_string(index=False), visualize_data(results)
# Page layout: a question box and a button, followed by three outputs
# (generated SQL, result table as text, chart image).
with gr.Blocks() as demo:
    gr.Markdown("## SQL Explorer: Text to SQL with a Simple Visualization")

    query_input = gr.Textbox(
        label="Enter your query",
        placeholder="Enter your query in English.",
    )
    submit_btn = gr.Button("Convert & Execute")

    sql_output = gr.Textbox(label="Generated SQL Query")
    table_output = gr.Textbox(label="Query Results")
    chart_output = gr.Image(label="Data Visualization")

    # The three values returned by gradio_ui fill the outputs in this order.
    submit_btn.click(
        fn=gradio_ui,
        inputs=[query_input],
        outputs=[sql_output, table_output, chart_output],
    )

# Launch
demo.launch()
|