# Hugging Face Space app source (Space status: Running; file size: 4,647 bytes).
# Revisions touching this file: b480321, 3ddc773, 79ceb52, 08e4afd, 350e55d.
import gradio as gr
import openai
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
from typing import Optional, Tuple
# OpenRouter API key. It is read from the environment (e.g. a Hugging Face
# Space secret named OPENROUTER_API_KEY) rather than hard-coded: a key committed
# to source is readable by anyone who can see the repository. Any key that was
# previously committed in this file should be treated as leaked and revoked.
OPENROUTER_API_KEY = os.environ.get("OPENROUTER_API_KEY", "")
if not OPENROUTER_API_KEY:
    print("Warning: OPENROUTER_API_KEY is not set; SQL generation requests will fail.")
OPENROUTER_MODEL = "sophosympatheia/rogue-rose-103b-v0.2:free"
# Hugging Face Space path
DB_PATH = "ecommerce.db"
# Where to fetch the database from when it is missing locally.
DATASET_URL = "https://your-dataset-link.com/ecommerce.db"  # Replace with actual dataset link
# Ensure dataset exists: download it once if it is not already on disk.
if not os.path.exists(DB_PATH):
    import urllib.request

    try:
        # Pure-Python download: no dependency on a shell or a `wget` binary,
        # and failures raise instead of being silently ignored.
        urllib.request.urlretrieve(DATASET_URL, DB_PATH)
    except (OSError, ValueError) as exc:
        # A failed download can leave a partial file behind, which would be
        # mistaken for the real database on the next start; remove it.
        if os.path.exists(DB_PATH):
            os.remove(DB_PATH)
        print(f"Warning: could not download dataset from {DATASET_URL}: {exc}")
# Initialize OpenAI client; OpenRouter is reached through the OpenAI SDK by
# overriding base_url.
openai_client = openai.OpenAI(api_key=OPENROUTER_API_KEY, base_url="https://openrouter.ai/api/v1")
# Function: Fetch database schema
def fetch_schema(db_path: str) -> str:
    """Return a plain-text description of every table in an SQLite database.

    Each table produces a ``Table: <name>`` line followed by one
    `` Column: <name>, Type: <type>`` line per column, in the order SQLite
    reports them. The text is intended to be embedded in an LLM prompt.

    Args:
        db_path: Path of the SQLite database file.

    Returns:
        The schema description, or an empty string if there are no tables.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        table_names = [row[0] for row in cursor.fetchall()]
        lines = []
        for table_name in table_names:
            # PRAGMA arguments cannot be bound as parameters, so quote the
            # identifier by hand; names containing spaces, keywords or quotes
            # would otherwise break the statement.
            quoted = '"' + table_name.replace('"', '""') + '"'
            cursor.execute(f"PRAGMA table_info({quoted});")
            lines.append(f"Table: {table_name}\n")
            # table_info rows: (cid, name, type, notnull, dflt_value, pk)
            for column in cursor.fetchall():
                lines.append(f" Column: {column[1]}, Type: {column[2]}\n")
        return "".join(lines)
    finally:
        # Close even when a query raises, so the file handle is not leaked.
        conn.close()
# Function: Convert text to SQL
def text_to_sql(query: str, schema: str) -> str:
    """Ask the configured OpenRouter model to translate a question into SQL.

    Args:
        query: The user's natural-language question.
        schema: Textual schema description that grounds the model in the
            real table and column names.

    Returns:
        The generated SQL with surrounding whitespace and any Markdown code
        fence removed. If the API call fails or the model returns no content,
        a string starting with ``"Error: "`` is returned instead of raising,
        so callers should check for that prefix before executing the result.
    """
    import re

    prompt = (
        "You are an SQL expert. Given the following database schema:\n\n"
        f"{schema}\n\n"
        "Convert the following query into SQL:\n\n"
        f"Query: {query}\n"
        "SQL:"
    )
    try:
        response = openai_client.chat.completions.create(
            model=OPENROUTER_MODEL,
            messages=[
                {"role": "system", "content": "You are an SQL expert."},
                {"role": "user", "content": prompt},
            ],
        )
        content = response.choices[0].message.content
    except Exception as e:
        # Network, authentication and rate-limit failures all land here; they
        # are reported as text so the UI can show them instead of crashing.
        return f"Error: {e}"
    if not content:
        return "Error: the model returned an empty response"
    sql_query = content.strip()
    # Chat models frequently wrap SQL in a ```sql ... ``` fence, sometimes with
    # prose around it. SQLite cannot execute the fence markers, so keep only
    # the fenced body when a fence is present.
    fenced = re.search(r"```(?:sqlite|sql)?\s*(.*?)```", sql_query, re.DOTALL | re.IGNORECASE)
    if fenced:
        sql_query = fenced.group(1).strip()
    return sql_query
# Function: Execute SQL on SQLite database
def execute_sql(
    sql_query: str, db_path: Optional[str] = None
) -> Tuple[Optional[pd.DataFrame], Optional[str]]:
    """Run an SQL statement against the SQLite database and return its rows.

    The database is opened read-only. The SQL typically comes from a language
    model and must be treated as untrusted, so it must not be able to modify
    or drop data.

    Args:
        sql_query: The SQL text to execute.
        db_path: Database file to query; defaults to the module-level
            ``DB_PATH`` when omitted.

    Returns:
        ``(dataframe, None)`` on success. If opening the database or executing
        the statement fails, including any attempted write, it returns
        ``(None, message)``, where the message starts with
        ``"SQL Execution Error: "``.
    """
    import pathlib

    path = DB_PATH if db_path is None else db_path
    conn = None
    try:
        # SQLite URI with mode=ro: writes are rejected, and a missing file is
        # reported as an error instead of being silently created empty.
        uri = pathlib.Path(path).absolute().as_uri() + "?mode=ro"
        conn = sqlite3.connect(uri, uri=True)
        df = pd.read_sql_query(sql_query, conn)
    except Exception as e:
        return None, f"SQL Execution Error: {e}"
    finally:
        # Always release the connection, including on the error path.
        if conn is not None:
            conn.close()
    return df, None
# Function: Generate Dynamic Visualization
def visualize_data(df: pd.DataFrame) -> Optional[str]:
    """Render a chart for a query result and save it as ``chart.png``.

    The chart type is chosen from the shape of the data:

    * exactly one numeric column: histogram of that column;
    * exactly two numeric columns: scatter plot of the first against the second;
    * three or more numeric columns and fewer than 10 rows: pie chart of the
      first numeric column, labelled by the first column;
    * otherwise: bar chart of the first numeric column against the first column.

    Args:
        df: Query result to plot.

    Returns:
        ``"chart.png"`` on success. Returns ``None`` when the result is empty,
        has fewer than two columns, has no numeric column, or the chosen plot
        cannot be drawn from the data.
    """
    if df.empty or df.shape[1] < 2:
        return None
    # Detect numeric columns
    numeric_cols = df.select_dtypes(include=["number"]).columns
    if len(numeric_cols) < 1:
        return None
    # Create the figure only once something will be drawn, and close it in
    # ``finally``. pyplot keeps every figure open until it is closed, so an
    # unclosed figure per call would accumulate across requests.
    fig = plt.figure(figsize=(6, 4))
    sns.set_theme(style="darkgrid")
    try:
        # Choose visualization type dynamically
        if len(numeric_cols) == 1:  # Single numeric column, assume it's a count metric
            sns.histplot(df[numeric_cols[0]], bins=10, kde=True, color="teal")
            plt.title(f"Distribution of {numeric_cols[0]}")
        elif len(numeric_cols) == 2:  # Two numeric columns, assume X-Y plot
            sns.scatterplot(x=df[numeric_cols[0]], y=df[numeric_cols[1]], color="blue")
            plt.title(f"{numeric_cols[0]} vs {numeric_cols[1]}")
        elif df.shape[0] < 10:  # If rows are few, prefer pie chart
            plt.pie(df[numeric_cols[0]], labels=df.iloc[:, 0], autopct="%1.1f%%", colors=sns.color_palette("pastel"))
            plt.title(f"Proportion of {numeric_cols[0]}")
        else:  # Default: Bar chart for categories + values
            sns.barplot(x=df.iloc[:, 0], y=df[numeric_cols[0]], palette="coolwarm")
            plt.xticks(rotation=45)
            plt.title(f"{df.columns[0]} vs {numeric_cols[0]}")
        plt.tight_layout()
        plt.savefig("chart.png")
    except (TypeError, ValueError) as e:
        # The chart is best-effort. Data the chosen plot cannot handle (e.g.
        # negative pie wedge sizes) should not take down the whole request;
        # the query results are still shown without a chart.
        print(f"Visualization skipped: {e}")
        return None
    finally:
        plt.close(fig)
    return "chart.png"
# Gradio UI
def gradio_ui(query: str) -> Tuple[str, str, Optional[str]]:
    """Handle one UI request: question -> SQL -> result rows -> chart.

    Args:
        query: Natural-language question typed by the user.

    Returns:
        ``(sql_text, results_text, chart_path)``. When a step fails, the
        second element holds a message describing the failure and
        ``chart_path`` is ``None``.
    """
    if not query or not query.strip():
        return "", "Please enter a query.", None
    schema = fetch_schema(DB_PATH)
    sql_query = text_to_sql(query, schema)
    # text_to_sql reports API failures as an "Error: ..." string instead of
    # raising; don't send that text to the database as if it were SQL.
    if sql_query.startswith("Error:"):
        return sql_query, "SQL generation failed; nothing was executed.", None
    df, error = execute_sql(sql_query)
    if error or df is None:
        return sql_query, error or "No results.", None
    visualization = visualize_data(df)
    return sql_query, df.to_string(index=False), visualization
# Launch Gradio App
with gr.Blocks() as demo:
    gr.Markdown("## SQL Explorer: Text-to-SQL with Real Execution & Visualization")
    query_input = gr.Textbox(label="Enter your query", placeholder="e.g., Show all products sold in 2018.")
    submit_btn = gr.Button("Convert & Execute")
    sql_output = gr.Textbox(label="Generated SQL Query")
    table_output = gr.Textbox(label="Query Results")
    chart_output = gr.Image(label="Data Visualization")
    # The three outputs map positionally onto gradio_ui's (sql, results, chart) tuple.
    submit_btn.click(gradio_ui, inputs=[query_input], outputs=[sql_output, table_output, chart_output])
demo.launch()