# SQL Explorer: Gradio app that turns an English question into SQLite SQL via an
# OpenRouter-hosted model, runs it against ecommerce.db, and charts the result.
import gradio as gr
import openai
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os

# OpenRouter credentials and model selection.
# The API key is read from the OPENROUTER_API_KEY environment variable so the
# secret is never stored in source control. When the variable is unset the key
# is an empty string, and requests are then expected to be rejected by the API.
# NOTE: any key that was previously committed in this file should be treated as
# compromised and rotated.
OPENROUTER_API_KEY = os.environ.get("OPENROUTER_API_KEY", "")
OPENROUTER_MODEL = "sophosympatheia/rogue-rose-103b-v0.2:free"


# Path of the SQLite database file that generated queries are run against
db_path = "ecommerce.db"

# Startup check: only prints a notice when the file is absent; the script
# keeps running either way
_db_found = os.path.exists(db_path)
if not _db_found:
    print("Database file not found! Please upload ecommerce.db.")

# Build the OpenAI-compatible client, pointed at the OpenRouter endpoint
openai_client = openai.OpenAI(
    base_url="https://openrouter.ai/api/v1",
    api_key=OPENROUTER_API_KEY,
)

# Few-shot examples (English question -> SQLite-compatible SQL) that are
# prepended to every prompt sent to the model
few_shot_examples = [
    {"input": question, "output": sql}
    for question, sql in (
        (
            "Find the busiest months for orders.",
            "SELECT strftime('%m', order_purchase_timestamp) AS month, COUNT(*) AS order_count FROM orders GROUP BY month ORDER BY order_count DESC;",
        ),
        (
            "Show all customers from São Paulo.",
            "SELECT * FROM customers WHERE customer_state = 'SP';",
        ),
        (
            "Find the total sales per product.",
            "SELECT product_id, SUM(price) FROM order_items GROUP BY product_id;",
        ),
        (
            "List all orders placed in 2017.",
            "SELECT * FROM orders WHERE order_purchase_timestamp LIKE '2017%';",
        ),
    )
]

# Function: Convert Text to SQL

def _strip_code_fence(text):
    """Return *text* without a surrounding Markdown code fence.

    Handles replies shaped like ```` ```sql\\nSELECT ...\\n``` ```` as well as
    a bare ```` ``` ```` fence; text without a leading fence is only stripped
    of surrounding whitespace.
    """
    text = text.strip()
    if not text.startswith("```"):
        return text

    text = text[3:]
    if text.endswith("```"):
        text = text[:-3]

    # Drop the first line only when it is empty or a language tag, so a query
    # that starts right after the opening fence is kept intact.
    first_line, newline, remainder = text.partition("\n")
    if newline and first_line.strip().lower() in ("", "sql", "sqlite"):
        text = remainder
    return text.strip()


def text_to_sql(query):
    """Ask the model to translate an English *query* into SQLite SQL.

    The prompt is an instruction line, every entry of ``few_shot_examples``
    rendered as an ``Input:``/``Output:`` pair, and finally *query* with an
    open ``Output:`` for the model to complete.

    Returns:
        The generated SQL when it starts with ``select`` (case-insensitive).
        Otherwise a string starting with ``"Error:"``: either
        ``"Error: Invalid SQL generated - ..."`` when the reply is not a
        SELECT statement, or ``"Error: <exception>"`` when the request fails.
    """
    example_text = "".join(
        f"Input: {example['input']}\nOutput: {example['output']}\n\n"
        for example in few_shot_examples
    )
    prompt = (
        "Convert the following queries into SQLite-compatible SQL:\n\n"
        f"{example_text}"
        f"Input: {query}\nOutput:"
    )

    try:
        response = openai_client.chat.completions.create(
            model=OPENROUTER_MODEL,
            messages=[{"role": "system", "content": "You are an SQLite expert."},
                      {"role": "user", "content": prompt}]
        )
        # The content may be None; treat that as an empty reply so it is
        # reported as invalid SQL rather than as an attribute error.
        raw_reply = response.choices[0].message.content or ""
    except Exception as e:
        return f"Error: {e}"

    # Models often wrap the answer in a Markdown code fence; remove it before
    # validating, otherwise a correct query would be rejected.
    sql_query = _strip_code_fence(raw_reply)
    if sql_query.lower().startswith("select"):
        return sql_query
    return f"Error: Invalid SQL generated - {sql_query}"

# Function: Execute SQL on SQLite Database

def execute_sql(sql_query):
    """Run *sql_query* against the SQLite database at ``db_path``.

    Returns:
        A ``pandas.DataFrame`` holding the result set on success, or a string
        of the form ``"SQL Execution Error: <exception>"`` when connecting or
        running the query raises.
    """
    try:
        conn = sqlite3.connect(db_path)
        try:
            return pd.read_sql_query(sql_query, conn)
        finally:
            # Close the connection even when the query fails, so a bad query
            # does not leave a connection open.
            conn.close()
    except Exception as e:
        return f"SQL Execution Error: {e}"

# Function: Generate Data Visualization

def visualize_data(df):
    """Render a chart for *df* to ``chart.png`` and return that path.

    Returns ``None`` without plotting when *df* is empty, has fewer than two
    columns, or has no numeric column.

    The chart type depends on how many numeric columns *df* has:

    * one: histogram (with KDE) of that column
    * two: scatter plot of the first numeric column against the second
    * three or more: pie chart of the first numeric column when *df* has
      fewer than 10 rows, otherwise a bar plot; both use the first column of
      *df* for labels

    NOTE(review): the pie and bar branches are only reachable with three or
    more numeric columns; confirm that this is the intended selection rule.
    """
    if df.empty or df.shape[1] < 2:
        return None

    numeric_cols = df.select_dtypes(include=['number']).columns
    if len(numeric_cols) < 1:
        return None

    fig = plt.figure(figsize=(6, 4))
    try:
        sns.set_theme(style="darkgrid")

        if len(numeric_cols) == 1:
            sns.histplot(df[numeric_cols[0]], bins=10, kde=True, color="teal")
            plt.title(f"Distribution of {numeric_cols[0]}")
        elif len(numeric_cols) == 2:
            sns.scatterplot(x=df[numeric_cols[0]], y=df[numeric_cols[1]], color="blue")
            plt.title(f"{numeric_cols[0]} vs {numeric_cols[1]}")
        elif df.shape[0] < 10:
            plt.pie(df[numeric_cols[0]], labels=df.iloc[:, 0], autopct='%1.1f%%', colors=sns.color_palette("pastel"))
            plt.title(f"Proportion of {numeric_cols[0]}")
        else:
            sns.barplot(x=df.iloc[:, 0], y=df[numeric_cols[0]], palette="coolwarm")
            plt.xticks(rotation=45)
            plt.title(f"{df.columns[0]} vs {numeric_cols[0]}")

        plt.tight_layout()
        plt.savefig("chart.png")
    finally:
        # Each call creates a new pyplot figure; close it so figures do not
        # pile up across requests, including when plotting raises.
        plt.close(fig)
    return "chart.png"

# Gradio UI
def gradio_ui(query):
    """Handle one request: translate *query* to SQL, run it, and chart it.

    Returns:
        A 3-tuple ``(sql_text, results_text, chart_path)``:

        * ``sql_text`` is whatever ``text_to_sql`` returned (SQL or an
          ``"Error: ..."`` message).
        * ``results_text`` is the result table rendered with
          ``DataFrame.to_string(index=False)``, the error string from
          ``execute_sql``, or a notice that nothing was executed.
        * ``chart_path`` is the value returned by ``visualize_data`` for a
          successful query, otherwise ``None``.
    """
    sql_query = text_to_sql(query)

    # text_to_sql reports failures as strings starting with "Error:". Running
    # such a string as SQL would only replace the real message with a
    # confusing syntax error, so stop here.
    if sql_query.startswith("Error:"):
        return sql_query, "Query was not executed because no valid SQL was generated.", None

    results = execute_sql(sql_query)
    if not isinstance(results, pd.DataFrame):
        # execute_sql returned its error message instead of a DataFrame.
        return sql_query, results, None

    return sql_query, results.to_string(index=False), visualize_data(results)

# Build the page: a heading, one query box, a submit button, and three outputs
with gr.Blocks() as demo:
    gr.Markdown("## SQL Explorer: Text to SQL with a Simple Visualization")
    # Input: the question typed by the user
    query_input = gr.Textbox(label="Enter your query", placeholder="Enter your query in English.")
    submit_btn = gr.Button("Convert & Execute")
    # Outputs: generated SQL, query results as text, and the chart image
    sql_output = gr.Textbox(label="Generated SQL Query")
    table_output = gr.Textbox(label="Query Results")
    chart_output = gr.Image(label="Data Visualization")

    # On click, gradio_ui is called with the query text; its three return
    # values are mapped to the three outputs in the order listed here
    submit_btn.click(gradio_ui, inputs=[query_input], outputs=[sql_output, table_output, chart_output])

# Launch the app (runs at module level, i.e. whenever this file is executed)
demo.launch()