Spaces:
Running
Running
File size: 4,257 Bytes
b480321 3ddc773 b480321 3ddc773 b480321 3ddc773 b480321 3ddc773 b480321 3ddc773 b480321 3ddc773 b480321 3ddc773 b480321 3ddc773 b480321 3ddc773 b480321 3ddc773 b480321 3ddc773 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
import gradio as gr
import openai
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
# OpenRouter credentials and model selection.
# The API key is read from the environment so that no secret lives in the
# source tree (on a Hugging Face Space, define OPENROUTER_API_KEY as a secret).
# NOTE(review): a key was previously committed in this file; treat it as leaked
# and rotate it.
# If the variable is unset the key is an empty string; the client is still
# created, and requests are then expected to fail with an authentication error.
OPENROUTER_API_KEY = os.environ.get("OPENROUTER_API_KEY", "")
OPENROUTER_MODEL = "sophosympatheia/rogue-rose-103b-v0.2:free"

# SQLite database file, resolved relative to the working directory.
DB_PATH = "ecommerce.db"

# Fetch the database on first start if it is not present yet.
if not os.path.exists(DB_PATH):
    os.system("wget https://your-dataset-link.com/ecommerce.db -O ecommerce.db")  # Replace with actual dataset link

# OpenAI-compatible client pointed at the OpenRouter endpoint.
openai_client = openai.OpenAI(api_key=OPENROUTER_API_KEY, base_url="https://openrouter.ai/api/v1")
# Few-shot (question, SQL) pairs that are prepended to every text-to-SQL prompt.
# Each entry is exposed as a dict with the keys "input" and "output".
few_shot_examples = [
    {"input": question, "output": sql}
    for question, sql in (
        (
            "Show all customers from São Paulo.",
            "SELECT * FROM customers WHERE customer_state = 'SP';",
        ),
        (
            "Find the total sales per product.",
            "SELECT product_id, SUM(price) FROM order_items GROUP BY product_id;",
        ),
        (
            "List all orders placed in 2017.",
            "SELECT * FROM orders WHERE order_purchase_timestamp LIKE '2017%';",
        ),
    )
]
# Function: Convert text to SQL
def text_to_sql(query):
    """Ask the configured chat model to translate a question into SQL.

    The prompt consists of a fixed instruction line, every entry of
    ``few_shot_examples`` rendered as an ``Input:``/``Output:`` pair, and
    finally the user's question with an open ``Output:`` slot.

    Args:
        query: Natural-language question to convert.

    Returns:
        The model's reply with surrounding whitespace stripped, or a string
        of the form ``"Error: <exception>"`` if the request or the response
        handling raises.
    """
    header = "Convert the following queries into SQL:\n\n"
    shots = "".join(
        f"Input: {shot['input']}\nOutput: {shot['output']}\n\n"
        for shot in few_shot_examples
    )
    prompt = f"{header}{shots}Input: {query}\nOutput:"

    messages = [
        {"role": "system", "content": "You are an SQL expert."},
        {"role": "user", "content": prompt},
    ]
    try:
        completion = openai_client.chat.completions.create(
            model=OPENROUTER_MODEL,
            messages=messages,
        )
        answer = completion.choices[0].message.content.strip()
    except Exception as exc:
        # Failures are reported in-band as text rather than raised.
        return f"Error: {exc}"
    else:
        return answer
# Function: Execute SQL on SQLite database
def execute_sql(sql_query, db_path=None):
    """Run a SQL statement against a SQLite database and return the rows.

    Args:
        sql_query: SQL text, passed unmodified to ``pandas.read_sql_query``.
        db_path: Optional path of the SQLite database to open. When omitted
            (``None``) the module-level ``DB_PATH`` is used.

    Returns:
        A ``pandas.DataFrame`` with the query result on success, otherwise a
        string of the form ``"SQL Execution Error: <exception>"``.

    NOTE(review): the statement is executed as-is; callers in this file pass
    model-generated SQL, so consider opening the database read-only.
    """
    try:
        conn = sqlite3.connect(DB_PATH if db_path is None else db_path)
        try:
            return pd.read_sql_query(sql_query, conn)
        finally:
            # Close the connection even when the query raises; previously a
            # failing query skipped close() and leaked the connection.
            conn.close()
    except Exception as e:
        return f"SQL Execution Error: {e}"
# Function: Generate Dynamic Visualization
def visualize_data(df):
    """Render a chart for a query result and return the image path.

    The chart type depends on how many numeric columns ``df`` has:

    * exactly one  -> histogram (10 bins, with KDE) of that column;
    * exactly two  -> scatter plot of the first against the second;
    * otherwise, fewer than 10 rows -> pie chart of the first numeric column,
      labelled with the values of the first column of ``df``;
    * otherwise -> bar chart of the first column of ``df`` against the first
      numeric column.

    Args:
        df: ``pandas.DataFrame`` holding the query result.

    Returns:
        ``"chart.png"`` (written to the working directory, overwriting any
        earlier chart), or ``None`` when ``df`` is empty, has fewer than two
        columns, or has no numeric column.
    """
    if df.empty or df.shape[1] < 2:
        return None

    # Detect numeric columns
    numeric_cols = df.select_dtypes(include=['number']).columns
    if len(numeric_cols) < 1:
        return None

    fig = plt.figure(figsize=(6, 4))
    try:
        sns.set_theme(style="darkgrid")

        # Choose visualization type dynamically
        if len(numeric_cols) == 1:  # Single numeric column, assume it's a count metric
            sns.histplot(df[numeric_cols[0]], bins=10, kde=True, color="teal")
            plt.title(f"Distribution of {numeric_cols[0]}")
        elif len(numeric_cols) == 2:  # Two numeric columns, assume X-Y plot
            sns.scatterplot(x=df[numeric_cols[0]], y=df[numeric_cols[1]], color="blue")
            plt.title(f"{numeric_cols[0]} vs {numeric_cols[1]}")
        elif df.shape[0] < 10:  # If rows are few, prefer pie chart
            plt.pie(df[numeric_cols[0]], labels=df.iloc[:, 0], autopct='%1.1f%%', colors=sns.color_palette("pastel"))
            plt.title(f"Proportion of {numeric_cols[0]}")
        else:  # Default: Bar chart for categories + values
            sns.barplot(x=df.iloc[:, 0], y=df[numeric_cols[0]], palette="coolwarm")
            plt.xticks(rotation=45)
            plt.title(f"{df.columns[0]} vs {numeric_cols[0]}")

        plt.tight_layout()
        plt.savefig("chart.png")
    finally:
        # pyplot keeps every figure alive until it is closed; without this,
        # each request would leave one more figure in memory.
        plt.close(fig)
    return "chart.png"
# Gradio UI
def gradio_ui(query):
    """Handle one submission: generate SQL, run it, and chart the result.

    Args:
        query: Natural-language question typed by the user.

    Returns:
        A 3-tuple ``(sql_text, results_text, chart_path)``:

        * ``sql_text`` - the generated SQL, or the ``"Error: ..."`` string
          from ``text_to_sql`` when generation failed;
        * ``results_text`` - the result table rendered without the index, the
          execution error message, or a short notice when nothing was run;
        * ``chart_path`` - the image path from ``visualize_data`` or ``None``.
    """
    # Blank input: avoid a pointless model call.
    if not query or not query.strip():
        return "", "Please enter a query.", None

    sql_query = text_to_sql(query)
    # text_to_sql reports failures in-band as "Error: ..."; running that text
    # as SQL would only produce a misleading syntax error.
    if sql_query.startswith("Error:"):
        return sql_query, "SQL generation failed; no query was executed.", None

    results = execute_sql(sql_query)
    if not isinstance(results, pd.DataFrame):
        # execute_sql returned its error message string.
        return sql_query, results, None

    return sql_query, results.to_string(index=False), visualize_data(results)
# Page layout: a heading, one question box, a trigger button, and three
# outputs (generated SQL, result text, chart image).
with gr.Blocks() as demo:
    gr.Markdown("## SQL Explorer: Text-to-SQL with Real Execution & Visualization")

    query_input = gr.Textbox(
        label="Enter your query",
        placeholder="e.g., Show all products sold in 2018.",
    )
    submit_btn = gr.Button("Convert & Execute")

    sql_output = gr.Textbox(label="Generated SQL Query")
    table_output = gr.Textbox(label="Query Results")
    chart_output = gr.Image(label="Data Visualization")

    # The three values returned by gradio_ui map onto the outputs in order.
    submit_btn.click(
        fn=gradio_ui,
        inputs=[query_input],
        outputs=[sql_output, table_output, chart_output],
    )

# Launch
# NOTE(review): this runs unconditionally, so importing the module also starts
# the app; consider an `if __name__ == "__main__":` guard.
demo.launch()
|