nl2sql / app.py
jk12p's picture
Update app.py
991db49 verified
raw
history blame
9.45 kB
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import re
import sqlparse
# Hugging Face Hub checkpoint shared by the model and its tokenizer.
_MODEL_ID = "onkolahmet/Qwen2-0.5B-Instruct-SQL-generator"

# Device label used for input tensors: CUDA when torch reports it, else CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Both dtype and device placement are delegated to the library ("auto").
model = AutoModelForCausalLM.from_pretrained(
    _MODEL_ID,
    torch_dtype="auto",
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
# # Few-shot examples to include in each prompt
# examples = [
# {
# "question": "Get the names and emails of customers who placed an order in the last 30 days.",
# "sql": "SELECT name, email FROM customers WHERE order_date >= DATE_SUB(CURDATE(), INTERVAL 30 DAY);"
# },
# {
# "question": "Find all employees with a salary greater than 50000.",
# "sql": "SELECT * FROM employees WHERE salary > 50000;"
# },
# {
# "question": "List all product names and their categories where the price is below 50.",
# "sql": "SELECT name, category FROM products WHERE price < 50;"
# },
# {
# "question": "How many users registered in the year 2022?",
# "sql": "SELECT COUNT(*) FROM users WHERE YEAR(registration_date) = 2022;"
# }
# ]
def generate_sql(question, context=None):
    """Generate a raw SQL completion for a natural-language question.

    The prompt is a fixed instruction line, an optional "Table Context"
    section, and the question in ``Q: ... / SQL:`` form. Few-shot examples
    are currently disabled (the ``examples`` list in this file is commented
    out), so none are added to the prompt.

    Args:
        question: Natural-language question to translate.
        context: Optional schema / table definitions. Skipped when it is
            ``None``, empty, or whitespace-only.

    Returns:
        Only the newly generated text (the prompt is sliced off), decoded
        without special tokens and stripped of surrounding whitespace.
        Sampling is enabled, so repeated calls may return different text.
    """
    prompt_parts = ["Translate natural language questions to SQL queries.\n\n"]
    if context and context.strip():
        prompt_parts.append(f"Table Context:\n{context}\n\n")
    prompt_parts.append(f"Q: {question}\nSQL:")
    prompt = "".join(prompt_parts)

    # Move the inputs to wherever the model was actually placed; with
    # device_map="auto" that is not guaranteed to match a hand-picked device.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_length = inputs["input_ids"].shape[-1]

    # Pass the whole encoding so the attention mask reaches generate(), and
    # skip autograd bookkeeping since this is inference only.
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=128,
            do_sample=True,
            eos_token_id=tokenizer.eos_token_id,
        )

    # The output holds prompt + continuation; keep just the continuation.
    sql_query = tokenizer.decode(
        outputs[0][prompt_length:], skip_special_tokens=True
    )
    return sql_query.strip()
# SELECT as a whole word, so identifiers such as "preselect" are not matched.
_SELECT_KEYWORD = re.compile(r'\bSELECT\b', re.IGNORECASE)

# Text ending in a set operator: the SELECT that follows continues the same
# statement (e.g. "... UNION ALL SELECT ...") rather than starting a new one.
_SET_OPERATOR_TAIL = re.compile(
    r'\b(?:UNION|INTERSECT|EXCEPT|MINUS)(?:\s+(?:ALL|DISTINCT))?\s*$',
    re.IGNORECASE,
)


def _top_level_select_offsets(text):
    """Return offsets of SELECT keywords that start a new top-level statement.

    A SELECT is skipped when it sits inside parentheses (subquery / CTE body),
    inside a single-quoted string literal, or directly after a set operator.
    """
    offsets = []
    depth = 0
    in_string = False
    cursor = 0
    for match in _SELECT_KEYWORD.finditer(text):
        # Update quote / parenthesis state for the text since the last match.
        for ch in text[cursor:match.start()]:
            if ch == "'":
                in_string = not in_string
            elif not in_string:
                if ch == '(':
                    depth += 1
                elif ch == ')':
                    depth = max(depth - 1, 0)
        cursor = match.start()
        if in_string or depth > 0:
            continue
        if _SET_OPERATOR_TAIL.search(text[:match.start()]):
            continue
        offsets.append(match.start())
    return offsets


def _split_statements(sql_text):
    """Split text into candidate statements; returns only non-empty pieces."""
    if ';' in sql_text:
        return [part.strip() for part in sql_text.split(';') if part.strip()]

    # No semicolons: fall back to cutting at top-level SELECT keywords.
    collapsed = re.sub(r'\s+', ' ', sql_text)
    offsets = _top_level_select_offsets(collapsed)
    if len(offsets) < 2:
        return [sql_text] if sql_text.strip() else []

    # Text before the first top-level SELECT is dropped.
    ends = offsets[1:] + [len(collapsed)]
    pieces = (collapsed[start:end].strip() for start, end in zip(offsets, ends))
    return [piece for piece in pieces if piece]


def _comparison_key(query):
    """Return a case/whitespace-insensitive form of *query* for deduplication."""
    terminated = query if query.strip().endswith(';') else query + ';'
    try:
        return sqlparse.format(
            terminated,
            keyword_case='lower',
            identifier_case='lower',
            strip_comments=True,
            reindent=True,
        )
    except Exception:
        # Best effort: if sqlparse is unusable, compare on lowercase text
        # with collapsed whitespace instead.
        return re.sub(r'\s+', ' ', query.lower().strip())


def clean_sql_output(sql_text):
    """Reduce raw model output to a single formatted SQL statement.

    Steps:
        1. Strip ``--`` and ``/* */`` comments and markdown code fences
           (the ``sql`` fence tag is matched case-insensitively).
        2. Split into statements on ``;``. Without any ``;``, split only at
           SELECT keywords that begin a new top-level statement, so
           subqueries, CTE bodies and UNION/INTERSECT/EXCEPT branches stay
           attached to their parent statement.
        3. Drop duplicates (compared after normalisation), keeping the
           first occurrence.
        4. Pick the longest statement containing SELECT, or the longest
           statement overall if none does; ties go to the earliest.
        5. Ensure a trailing ``;``, collapse whitespace and pretty-print
           with sqlparse.

    Args:
        sql_text: Raw text produced by the model.

    Returns:
        The formatted statement, or ``""`` when nothing usable remains. If
        sqlparse formatting fails, the single-line unformatted statement is
        returned instead.
    """
    if not sql_text:
        return ""

    sql_text = re.sub(r'--.*?$', '', sql_text, flags=re.MULTILINE)
    sql_text = re.sub(r'/\*.*?\*/', '', sql_text, flags=re.DOTALL)
    sql_text = re.sub(r'```(?:sql)?', '', sql_text, flags=re.IGNORECASE)

    queries = _split_statements(sql_text)
    if not queries:
        return ""

    if len(queries) > 1:
        # dict preserves insertion order, so the first spelling of each
        # distinct statement is the one kept.
        unique = {}
        for query in queries:
            unique.setdefault(_comparison_key(query), query)
        candidates = list(unique.values())
        select_candidates = [q for q in candidates if _SELECT_KEYWORD.search(q)]
        # max() returns the first of equally long items.
        best_query = max(select_candidates or candidates, key=len)
    else:
        best_query = queries[0]

    best_query = best_query.strip()
    if not best_query.endswith(';'):
        best_query += ';'
    best_query = re.sub(r'\s+', ' ', best_query)

    try:
        return sqlparse.format(
            best_query,
            keyword_case='upper',
            identifier_case='lower',
            reindent=True,
            indent_width=2,
        )
    except Exception:
        # Formatting is cosmetic; fall back to the unformatted statement.
        return best_query
def process_input(question, table_context):
    """Turn a user question into a cleaned SQL string for the UI.

    Args:
        question: Natural-language question from the question textbox.
        table_context: Optional schema text; forwarded to ``generate_sql``
            unchanged.

    Returns:
        The cleaned SQL query, or a short user-facing message when the
        question is missing/blank or no usable SQL could be extracted.
    """
    # Treat None the same as an empty string instead of raising on .strip().
    if not question or not question.strip():
        return "Please enter a question."

    raw_sql = generate_sql(question, table_context)
    cleaned_sql = clean_sql_output(raw_sql)

    if not cleaned_sql:
        return "Sorry, I couldn't generate a valid SQL query. Please try rephrasing your question."
    return cleaned_sql
# Sample table contexts for the example selector.
# NOTE(review): `example_contexts` was referenced below but never defined in
# this file, which raises NameError at import time. The schemas here are
# placeholders chosen to match the example questions (0 = customers/orders,
# 1 = products, 2 = employees/departments) -- replace with the intended ones
# if they exist elsewhere.
example_contexts = [
    """CREATE TABLE customers (
    id INT PRIMARY KEY,
    name VARCHAR(100),
    email VARCHAR(100)
);

CREATE TABLE orders (
    id INT PRIMARY KEY,
    customer_id INT,
    order_date DATE,
    total_amount DECIMAL(10, 2),
    FOREIGN KEY (customer_id) REFERENCES customers(id)
);""",
    """CREATE TABLE products (
    id INT PRIMARY KEY,
    name VARCHAR(100),
    category VARCHAR(50),
    price DECIMAL(10, 2),
    stock_quantity INT
);""",
    """CREATE TABLE departments (
    id INT PRIMARY KEY,
    name VARCHAR(100)
);

CREATE TABLE employees (
    id INT PRIMARY KEY,
    name VARCHAR(100),
    department_id INT,
    salary DECIMAL(10, 2),
    hire_date DATE,
    FOREIGN KEY (department_id) REFERENCES departments(id)
);""",
]

# Create the Gradio interface: inputs on the left, generated SQL on the right.
with gr.Blocks(title="Text to SQL Converter") as demo:
    gr.Markdown("# Text to SQL Query Converter")
    gr.Markdown("Enter your question and optional table context to generate an SQL query.")

    with gr.Row():
        with gr.Column():
            question_input = gr.Textbox(
                label="Your Question",
                placeholder="e.g., Find all products with price less than $50",
                lines=2
            )
            table_context = gr.Textbox(
                label="Table Context (Optional)",
                placeholder="Enter your database schema or table definitions here...",
                lines=10
            )
            submit_btn = gr.Button("Generate SQL Query")
        with gr.Column():
            sql_output = gr.Code(
                label="Generated SQL Query",
                language="sql",
                lines=12
            )

    # Examples section: each row fills both the question and the context box.
    gr.Markdown("### Try some examples")
    example_selector = gr.Examples(
        examples=[
            ["List all products in the 'Electronics' category with price less than $500", example_contexts[1]],
            ["Find the total number of employees in each department", example_contexts[2]],
            ["Get customers who placed orders in the last 7 days", example_contexts[0]],
            ["Count the number of products in each category", example_contexts[1]],
            ["Find the average salary by department", example_contexts[2]]
        ],
        inputs=[question_input, table_context]
    )

    # Set up the submit button to trigger the process_input function
    submit_btn.click(
        fn=process_input,
        inputs=[question_input, table_context],
        outputs=sql_output
    )

    # Also trigger on pressing Enter in the question input
    question_input.submit(
        fn=process_input,
        inputs=[question_input, table_context],
        outputs=sql_output
    )

    # Add information about the model. The text is kept flush-left so the
    # Markdown source does not depend on any dedenting by the component.
    gr.Markdown("""
### About
This app uses a fine-tuned language model to convert natural language questions into SQL queries.
- **Model**: [onkolahmet/Qwen2-0.5B-Instruct-SQL-generator](https://huggingface.co/onkolahmet/Qwen2-0.5B-Instruct-SQL-generator)
- **How to use**:
1. Enter your question in natural language
2. If you have specific table schemas, add them in the Table Context field
3. Click "Generate SQL Query" or press Enter
Note: The model works best when table context is provided, but can generate generic SQL queries without it.
""")

# Launch the app
demo.launch()