File size: 4,413 Bytes
b480321
 
3ddc773
 
 
 
b480321
 
3ddc773
b480321
 
 
08e4afd
 
b480321
350e55d
08e4afd
 
3ddc773
 
 
 
08e4afd
b480321
08e4afd
 
 
b480321
 
08e4afd
b480321
08e4afd
b480321
 
 
 
 
 
 
08e4afd
b480321
350e55d
08e4afd
 
 
 
b480321
 
 
08e4afd
3ddc773
 
08e4afd
3ddc773
 
 
 
 
 
08e4afd
3ddc773
 
 
 
08e4afd
3ddc773
 
 
 
 
 
 
08e4afd
 
3ddc773
 
08e4afd
3ddc773
 
08e4afd
3ddc773
 
08e4afd
3ddc773
 
 
 
 
 
 
 
b480321
3ddc773
 
 
 
 
 
 
 
08e4afd
 
3ddc773
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import os
import pathlib
import re
import shutil
import sqlite3
import urllib.request

import gradio as gr
import matplotlib.pyplot as plt
import openai
import pandas as pd
import seaborn as sns

# OpenRouter credentials and model selection.
# The API key is read from the environment (e.g. a Hugging Face Space secret)
# instead of being hard-coded: a key committed to source is readable by anyone
# with access to the file and should be treated as compromised (revoke the key
# that was previously embedded here). An unset key yields "", so requests fail
# with an authentication error that text_to_sql reports, instead of failing at import.
OPENROUTER_API_KEY = os.environ.get("OPENROUTER_API_KEY", "")
# Overridable via the environment; the default is the model used so far.
OPENROUTER_MODEL = os.environ.get("OPENROUTER_MODEL", "sophosympatheia/rogue-rose-103b-v0.2:free")

# Hugging Face Space path
DB_PATH = "ecommerce.db"

# Ensure dataset exists
if not os.path.exists(DB_PATH):
    os.system("wget https://your-dataset-link.com/ecommerce.db -O ecommerce.db")  # Replace with actual dataset link

# OpenRouter exposes an OpenAI-compatible API, so the official OpenAI client is
# reused with its base URL pointed at OpenRouter.
openai_client = openai.OpenAI(
    base_url="https://openrouter.ai/api/v1",
    api_key=OPENROUTER_API_KEY,
)

# Few-shot examples for text-to-SQL: (question, target SQL) pairs expanded into
# the {"input": ..., "output": ...} dicts that the prompt builder formats.
few_shot_examples = [
    {"input": question, "output": sql}
    for question, sql in (
        ("Show all customers from São Paulo.", "SELECT * FROM customers WHERE customer_state = 'SP';"),
        ("Find the total sales per product.", "SELECT product_id, SUM(price) FROM order_items GROUP BY product_id;"),
        ("List all orders placed in 2017.", "SELECT * FROM orders WHERE order_purchase_timestamp LIKE '2017%';"),
    )
]

# Function: Convert text to SQL
def text_to_sql(query):
    prompt = "Convert the following queries into SQL:\n\n"
    for example in few_shot_examples:
        prompt += f"Input: {example['input']}\nOutput: {example['output']}\n\n"
    prompt += f"Input: {query}\nOutput:"

    try:
        response = openai_client.chat.completions.create(
            model=OPENROUTER_MODEL,
            messages=[{"role": "system", "content": "You are an SQL expert."}, {"role": "user", "content": prompt}]
        )
        sql_query = response.choices[0].message.content.strip()
        
        # Ensure only one query is returned (remove extra text)
        sql_query = sql_query.split("\n")[0].strip()
        return sql_query
    except Exception as e:
        return f"Error: {e}"

# Function: Execute SQL on SQLite database
def execute_sql(sql_query, db_path=None):
    """Run *sql_query* against the SQLite database and return the result.

    The connection is opened read-only. The SQL comes from a language model
    steered by free-form user text, so it must not be able to modify or drop
    data. Read-only mode also means a missing database file is reported as an
    error instead of being silently created empty. The connection is closed
    even when the query fails.

    Args:
        sql_query: SQL text to execute.
        db_path: Database file to query; defaults to the module-level DB_PATH.

    Returns:
        A pandas DataFrame with the result rows, or a string starting with
        "SQL Execution Error: " describing the failure.
    """
    path = DB_PATH if db_path is None else db_path
    try:
        # mode=ro requires the URI form; as_uri() percent-escapes unusual characters.
        conn = sqlite3.connect(pathlib.Path(path).resolve().as_uri() + "?mode=ro", uri=True)
        try:
            return pd.read_sql_query(sql_query, conn)
        finally:
            conn.close()
    except Exception as e:  # surface any connect/query failure to the UI as text
        return f"SQL Execution Error: {e}"

# Function: Generate Dynamic Visualization
def visualize_data(df, output_path="chart.png"):
    """Render a chart for a query result and save it as a PNG.

    The chart type depends on the numeric columns of *df*:
      * exactly one numeric column  -> histogram (with KDE) of that column;
      * exactly two numeric columns -> scatter plot of the first against the second;
      * three or more, < 10 rows    -> pie chart of the first numeric column,
                                       labelled by the first column;
      * otherwise                   -> bar chart of the first numeric column
                                       against the first column.

    Args:
        df: Query result to plot.
        output_path: File the chart image is written to (default "chart.png").

    Returns:
        *output_path* on success. Returns None when *df* is empty, has fewer than
        two columns, has no numeric column, or its data cannot be drawn (e.g.
        negative or NaN values in a pie chart).
    """
    if df.empty or df.shape[1] < 2:
        return None

    # Detect numeric columns
    numeric_cols = df.select_dtypes(include=["number"]).columns
    if len(numeric_cols) < 1:
        return None

    # The theme must be set before the axes exist for its style to apply to them.
    sns.set_theme(style="darkgrid")
    # Use an explicit figure and always close it. pyplot keeps every figure open
    # until closed, so a long-running server would otherwise leak one per request.
    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        first = numeric_cols[0]
        if len(numeric_cols) == 1:  # Single numeric column, assume it's a count metric
            sns.histplot(df[first], bins=10, kde=True, color="teal", ax=ax)
            ax.set_title(f"Distribution of {first}")
        elif len(numeric_cols) == 2:  # Two numeric columns, assume X-Y plot
            second = numeric_cols[1]
            sns.scatterplot(x=df[first], y=df[second], color="blue", ax=ax)
            ax.set_title(f"{first} vs {second}")
        elif df.shape[0] < 10:  # If rows are few, prefer pie chart
            ax.pie(df[first], labels=df.iloc[:, 0], autopct="%1.1f%%", colors=sns.color_palette("pastel"))
            ax.set_title(f"Proportion of {first}")
        else:  # Default: Bar chart for categories + values
            sns.barplot(x=df.iloc[:, 0], y=df[first], palette="coolwarm", ax=ax)
            ax.tick_params(axis="x", labelrotation=45)
            ax.set_title(f"{df.columns[0]} vs {first}")
        fig.tight_layout()
        fig.savefig(output_path)
    except (ValueError, TypeError):
        # The chart is optional; data that cannot be drawn must not hide the results table.
        return None
    finally:
        plt.close(fig)
    return output_path

# Gradio UI
def gradio_ui(query):
    """Gradio callback: turn a question into SQL, run it, and chart the result.

    Args:
        query: Natural-language question typed by the user.

    Returns:
        A 3-tuple for the (sql_output, table_output, chart_output) components:
          * the generated SQL, or an error message;
          * the result rows rendered as text, or an error/status message;
          * a chart image path, or None when no chart is drawn.
    """
    if not query or not query.strip():
        return "", "Please enter a query.", None

    sql_query = text_to_sql(query)
    # text_to_sql reports API failures as an "Error: ..." string instead of raising.
    # Executing that text as SQL would only replace the real cause with a syntax error.
    if sql_query.startswith("Error:"):
        return sql_query, "Query not executed because SQL generation failed.", None

    results = execute_sql(sql_query)
    # execute_sql returns a DataFrame on success and an error-message string on failure.
    if not isinstance(results, pd.DataFrame):
        return sql_query, results, None
    return sql_query, results.to_string(index=False), visualize_data(results)

# Single-page UI: a question box and a button, followed by three outputs
# (generated SQL, result rows as text, chart image).
with gr.Blocks() as demo:
    gr.Markdown("## SQL Explorer: Text-to-SQL with Real Execution & Visualization")
    query_input = gr.Textbox(
        label="Enter your query",
        placeholder="e.g., Show all products sold in 2018.",
    )
    submit_btn = gr.Button("Convert & Execute")
    sql_output = gr.Textbox(label="Generated SQL Query")
    table_output = gr.Textbox(label="Query Results")
    chart_output = gr.Image(label="Data Visualization")

    # Output order must match the 3-tuple returned by gradio_ui.
    submit_btn.click(
        fn=gradio_ui,
        inputs=[query_input],
        outputs=[sql_output, table_output, chart_output],
    )

# Launch the Gradio app.
demo.launch()