"""SQL Explorer: translate English questions into SQLite queries and chart the result.

Pipeline for one request:
1. ``text_to_sql`` asks an OpenRouter-hosted model for a SELECT statement.
2. ``execute_sql`` runs that statement against the local SQLite database.
3. ``visualize_data`` renders a simple chart of the resulting DataFrame.
4. ``gradio_ui`` wires the three steps together for the Gradio interface.
"""

import os
import sqlite3
from contextlib import closing

import gradio as gr
import matplotlib.pyplot as plt
import openai
import pandas as pd
import seaborn as sns

# OpenRouter credentials are read from the environment so that no secret is
# stored in source control. Any key that was previously committed in this file
# should be treated as leaked and revoked.
OPENROUTER_API_KEY = os.environ.get("OPENROUTER_API_KEY", "")
OPENROUTER_MODEL = "sophosympatheia/rogue-rose-103b-v0.2:free"

# Database Path
db_path = "ecommerce.db"

if not OPENROUTER_API_KEY:
    print("OPENROUTER_API_KEY is not set! Text-to-SQL requests will fail.")

# Ensure dataset exists
if not os.path.exists(db_path):
    print("Database file not found! Please upload ecommerce.db.")

# Initialize OpenAI client (pointed at the OpenRouter endpoint)
openai_client = openai.OpenAI(
    api_key=OPENROUTER_API_KEY,
    base_url="https://openrouter.ai/api/v1",
)

# Few-shot examples with SQLite-compatible queries
few_shot_examples = [
    {
        "input": "Find the busiest months for orders.",
        "output": "SELECT strftime('%m', order_purchase_timestamp) AS month, COUNT(*) AS order_count FROM orders GROUP BY month ORDER BY order_count DESC;",
    },
    {
        "input": "Show all customers from São Paulo.",
        "output": "SELECT * FROM customers WHERE customer_state = 'SP';",
    },
    {
        "input": "Find the total sales per product.",
        "output": "SELECT product_id, SUM(price) FROM order_items GROUP BY product_id;",
    },
    {
        "input": "List all orders placed in 2017.",
        "output": "SELECT * FROM orders WHERE order_purchase_timestamp LIKE '2017%';",
    },
]


def _strip_code_fence(text):
    """Return *text* without a surrounding Markdown ``` fence, if it has one."""
    text = text.strip()
    if not text.startswith("```"):
        return text
    lines = text.splitlines()[1:]  # drop the opening fence and its language tag
    if lines and lines[-1].strip().startswith("```"):
        lines = lines[:-1]
    return "\n".join(lines).strip()


def text_to_sql(query):
    """Convert an English question into a SQLite SELECT statement.

    Args:
        query: The user's question in plain English.

    Returns:
        The generated SQL string when it starts with ``SELECT``; otherwise a
        string starting with ``"Error:"`` that describes what went wrong
        (invalid model output or a failed API call).
    """
    parts = ["Convert the following queries into SQLite-compatible SQL:\n\n"]
    parts.extend(
        f"Input: {example['input']}\nOutput: {example['output']}\n\n"
        for example in few_shot_examples
    )
    parts.append(f"Input: {query}\nOutput:")
    prompt = "".join(parts)

    try:
        response = openai_client.chat.completions.create(
            model=OPENROUTER_MODEL,
            messages=[
                {"role": "system", "content": "You are an SQLite expert."},
                {"role": "user", "content": prompt},
            ],
        )
        raw_answer = (response.choices[0].message.content or "").strip()
    except Exception as e:
        # UI boundary: report the failure as text instead of crashing the app.
        return f"Error: {e}"

    # Models often wrap SQL in a Markdown fence; unwrap it before validating.
    sql_query = _strip_code_fence(raw_answer)
    if sql_query.lower().startswith("select"):
        return sql_query
    return f"Error: Invalid SQL generated - {raw_answer}"


def execute_sql(sql_query):
    """Run *sql_query* against the SQLite database at ``db_path``.

    Args:
        sql_query: The SQL statement to execute.

    Returns:
        A ``pandas.DataFrame`` with the result rows on success, or a string
        starting with ``"SQL Execution Error:"`` on failure.
    """
    # sqlite3.connect would silently create an empty database file here,
    # which hides the real problem on every later run.
    if not os.path.exists(db_path):
        return f"SQL Execution Error: database file '{db_path}' not found"
    try:
        # closing() guarantees the connection is released even when the
        # query raises.
        with closing(sqlite3.connect(db_path)) as conn:
            return pd.read_sql_query(sql_query, conn)
    except Exception as e:
        return f"SQL Execution Error: {e}"


def visualize_data(df):
    """Render a simple chart for *df* and save it to ``chart.png``.

    The chart type depends on the number of numeric columns:
    one -> histogram, two -> scatter plot, more -> pie chart when there are
    fewer than 10 rows, otherwise a bar plot.

    Args:
        df: Query result as a ``pandas.DataFrame``.

    Returns:
        The path ``"chart.png"``, or ``None`` when the frame is empty, has
        fewer than two columns, or has no numeric column.
    """
    if df.empty or df.shape[1] < 2:
        return None

    numeric_cols = df.select_dtypes(include=["number"]).columns
    if len(numeric_cols) < 1:
        return None

    # Apply the theme before the axes exist so that it takes effect.
    sns.set_theme(style="darkgrid")
    # An explicit figure avoids relying on pyplot's global "current figure".
    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        first_numeric = numeric_cols[0]
        if len(numeric_cols) == 1:
            sns.histplot(df[first_numeric], bins=10, kde=True, color="teal", ax=ax)
            ax.set_title(f"Distribution of {first_numeric}")
        elif len(numeric_cols) == 2:
            sns.scatterplot(
                x=df[numeric_cols[0]], y=df[numeric_cols[1]], color="blue", ax=ax
            )
            ax.set_title(f"{numeric_cols[0]} vs {numeric_cols[1]}")
        elif df.shape[0] < 10:
            ax.pie(
                df[first_numeric],
                labels=df.iloc[:, 0],
                autopct="%1.1f%%",
                colors=sns.color_palette("pastel"),
            )
            ax.set_title(f"Proportion of {first_numeric}")
        else:
            sns.barplot(x=df.iloc[:, 0], y=df[first_numeric], palette="coolwarm", ax=ax)
            ax.tick_params(axis="x", labelrotation=45)
            ax.set_title(f"{df.columns[0]} vs {first_numeric}")

        fig.tight_layout()
        # NOTE: the file name is fixed, so concurrent requests overwrite
        # each other's chart.
        fig.savefig("chart.png")
    finally:
        # Without this every request leaves a figure alive in memory.
        plt.close(fig)
    return "chart.png"


def gradio_ui(query):
    """Handle one UI request: generate SQL, execute it, and chart the result.

    Args:
        query: The user's question in plain English.

    Returns:
        A ``(sql_text, results_text, chart_path)`` tuple matching the three
        output components; ``chart_path`` is ``None`` when no chart exists.
    """
    sql_query = text_to_sql(query)
    if sql_query.startswith("Error:"):
        # Do not send an error message to SQLite as if it were a query.
        return sql_query, "Query was not executed because SQL generation failed.", None

    results = execute_sql(sql_query)
    if not isinstance(results, pd.DataFrame):
        return sql_query, results, None

    try:
        visualization = visualize_data(results)
    except Exception as e:
        # The chart is best-effort; a plotting failure must not hide the table.
        print(f"Visualization failed: {e}")
        visualization = None
    return sql_query, results.to_string(index=False), visualization


with gr.Blocks() as demo:
    gr.Markdown("## SQL Explorer: Text to SQL with a Simple Visualization")
    query_input = gr.Textbox(
        label="Enter your query", placeholder="Enter your query in English."
    )
    submit_btn = gr.Button("Convert & Execute")
    sql_output = gr.Textbox(label="Generated SQL Query")
    table_output = gr.Textbox(label="Query Results")
    chart_output = gr.Image(label="Data Visualization")
    submit_btn.click(
        gradio_ui,
        inputs=[query_input],
        outputs=[sql_output, table_output, chart_output],
    )

# Launch only when run as a script so importing this module has no server side effect.
if __name__ == "__main__":
    demo.launch()