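# Residue from the Hugging Face file viewer (author / commit info), kept as comments so the module parses:
# srishtirai's picture
# Update app.py
# f24c38b verified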
import gradio as gr
import re
import torch
import sqlite3 # Can be replaced with other DB connections
import os
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel, PeftConfig
from huggingface_hub import login
# Read the Hugging Face access token from the environment (e.g. a Space secret)
# rather than hard-coding it in the source.
HF_TOKEN = os.getenv("HF_TOKEN")
# Treat an unset, empty or whitespace-only secret as missing, so a misconfigured
# Space fails here with a clear message instead of handing a blank token to `login`.
if not HF_TOKEN or not HF_TOKEN.strip():
    raise ValueError("❌ ERROR: Hugging Face token is missing! Please set HF_TOKEN in Hugging Face Secrets.")
# Authenticate this process with the Hub before any model download below.
login(token=HF_TOKEN)
# Directory passed as `offload_folder` when loading base models in load_model.
os.makedirs("offload", exist_ok=True)
# Hub repo ids of the two fine-tuned (LoRA adapter) models served by this app.
_HUB_NAMESPACE = "srishtirai"
codellama_model_path = f"{_HUB_NAMESPACE}/codellama-sql-finetuned"
mistral_model_path = f"{_HUB_NAMESPACE}/mistral-sql-finetuned"
def load_model(model_path):
    """Load a PEFT/LoRA fine-tuned causal LM and its tokenizer from the Hub.

    Args:
        model_path: Hub repo id of the adapter (e.g. "srishtirai/codellama-sql-finetuned").
            The tokenizer is loaded from this repo, and the base model id is read
            from the adapter's PeftConfig.

    Returns:
        (model, tokenizer): the adapter-wrapped PeftModel in eval mode, and its tokenizer.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # Reuse EOS as the padding token (presumably the tokenizer has no dedicated
    # pad token; TODO confirm per model).
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    peft_config = PeftConfig.from_pretrained(model_path)
    base_model_name = peft_config.base_model_name_or_path

    # Half precision plus automatic device placement, with spill-over weights
    # written to the "offload" directory created at import time.
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype=torch.float16,
        device_map="auto",
        offload_folder="offload",
        # `token` replaces the deprecated `use_auth_token` argument.
        token=HF_TOKEN,
    )

    # Attach the LoRA adapter for inference only (no trainable parameters).
    model = PeftModel.from_pretrained(
        base_model,
        model_path,
        is_trainable=False,
    )
    model.eval()
    return model, tokenizer
# Load both fine-tuned models once, at import time; each call yields a
# (model, tokenizer) pair that the generation function selects between.
codellama_model, codellama_tokenizer = load_model(model_path=codellama_model_path)
mistral_model, mistral_tokenizer = load_model(model_path=mistral_model_path)
def format_input_prompt(schema, question):
    """Build the instruction prompt fed to the fine-tuned models.

    The prompt has "### Context:", "### Question:" and "### Response:" sections,
    one item per line, and ends with the lead-in "Here's the SQL query:" followed
    by a newline so the model continues with the query.
    """
    prompt_lines = [
        "### Context:",
        f"{schema}",
        "### Question:",
        f"{question}",
        "### Response:",
        "Here's the SQL query:",
    ]
    return "\n".join(prompt_lines) + "\n"
# Generate a SQL query and an explanation with one of the fine-tuned models.
def generate_sql_with_explanation(model_choice, schema, question, max_new_tokens=512, temperature=0.7):
    """
    Generate SQL query and explanation based on the selected model.

    Args:
        model_choice: "CodeLlama" selects the CodeLlama model; any other value
            (including None) falls back to the Mistral model.
        schema: Database schema text placed in the prompt's "### Context:" section.
        question: Natural-language question placed in the "### Question:" section.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature. Sampling is always enabled (top_p=0.95),
            so repeated calls may return different text.

    Returns:
        dict with keys:
          "query": text inside the first ```sql ... ``` block of the generated
              continuation, or "SQL query extraction failed." if there is none.
          "explanation": text after the first "Explanation:" in the generated
              continuation, up to the next blank line or the end, or
              "Explanation not found.".
          "full_response": the decoded output sequence (prompt + continuation).
    """
    if model_choice == "CodeLlama":
        model, tokenizer = codellama_model, codellama_tokenizer
    else:
        model, tokenizer = mistral_model, mistral_tokenizer

    prompt = format_input_prompt(schema, question)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Number of prompt tokens; generate() returns them ahead of the new tokens.
    prompt_length = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=0.95,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Keep "full_response" as before (prompt included), but parse only what the
    # model generated. Otherwise a schema pasted inside a ```sql fence, or one
    # containing "Explanation:", would be returned as the model's answer.
    full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    generated_text = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

    sql_match = re.search(r'```sql\s*(.*?)\s*```', generated_text, re.DOTALL)
    sql_query = sql_match.group(1).strip() if sql_match else None

    explanation_match = re.search(r'Explanation:\s*(.*?)($|\n\n)', generated_text, re.DOTALL)
    explanation = explanation_match.group(1).strip() if explanation_match else None

    return {
        "query": sql_query or "SQL query extraction failed.",
        "explanation": explanation or "Explanation not found.",
        "full_response": full_response,
    }
def _deny_attach(action, arg1, arg2, db_name, trigger_name):
    """sqlite3 authorizer callback: refuse ATTACH, allow every other action.

    The executed SQL comes from user input and model output, so it must not be
    able to open additional database files on the server.
    """
    if action == sqlite3.SQLITE_ATTACH:
        return sqlite3.SQLITE_DENY
    return sqlite3.SQLITE_OK


# Execute a generated SQL query against a throwaway SQLite database.
def execute_sql_query(sql_query, schema=None):
    """
    Runs the generated SQL query on a sample SQLite database.
    (You can replace SQLite with a connection to a real database)

    A fresh in-memory database is created for every call. If `schema` is given
    (non-blank), it is executed first as an SQL script (e.g. CREATE TABLE /
    INSERT statements), so the query has the user's tables to run against.

    Args:
        sql_query: A single SQL statement to execute.
        schema: Optional SQL script that sets up tables before the query runs.

    Returns:
        The fetched rows (list of tuples) if the query produced any rows.
        Otherwise returns one of these strings:
        - "Query executed successfully (No output rows)." when the query ran
          but produced no rows.
        - "Error applying schema: ..." when the schema script fails.
        - "Error executing SQL: ..." for any other failure.
    """
    conn = None
    try:
        conn = sqlite3.connect(":memory:")
        conn.set_authorizer(_deny_attach)
        if schema and schema.strip():
            try:
                conn.executescript(schema)
            except sqlite3.Error as e:
                return f"Error applying schema: {e}"
        cursor = conn.cursor()
        cursor.execute(sql_query)
        result = cursor.fetchall()
        return result if result else "Query executed successfully (No output rows)."
    except Exception as e:
        # UI boundary: report any failure to the user instead of raising.
        return f"Error executing SQL: {str(e)}"
    finally:
        # Always release the connection, including when execute() raised.
        if conn is not None:
            conn.close()


# Gradio callback: generate SQL + explanation, optionally executing the SQL.
def gradio_generate_sql(model_choice, schema, question, run_sql):
    """
    Takes model selection, schema & question as input and returns SQL + explanation.
    Optionally executes the SQL if requested.

    Returns:
        (sql_query, explanation, execution_result_or_status) for the three output widgets.
    """
    result = generate_sql_with_explanation(model_choice, schema, question)
    sql_query = result["query"]
    explanation = result["explanation"]
    if not run_sql:
        return sql_query, explanation, "SQL execution not requested."
    # Recreate the user's schema so the generated query has tables to query.
    return sql_query, explanation, execute_sql_query(sql_query, schema)
# Gradio front end: model picker, schema and question inputs, and an opt-in
# checkbox for running the generated SQL, wired to gradio_generate_sql.
iface = gr.Interface(
    fn=gradio_generate_sql,
    inputs=[
        # Explicit default so a submission never carries an empty model choice
        # (older Gradio versions start with no selection, i.e. None), which
        # generate_sql_with_explanation would otherwise route to Mistral.
        gr.Dropdown(["CodeLlama", "Mistral"], value="CodeLlama", label="Choose Model"),
        gr.Textbox(label="Enter Database Schema", lines=10),
        gr.Textbox(label="Enter your Question"),
        gr.Checkbox(label="Run SQL Query?", value=False),
    ],
    outputs=[
        gr.Code(label="Generated SQL Query", language="sql"),  # SQL syntax highlighting
        gr.Textbox(label="Explanation", lines=5),
        gr.Textbox(label="SQL Execution Result", lines=5),
    ],
    title="SQL Query Generator with Execution",
    description="Select a model, enter your database schema and question. Optionally, execute the generated SQL query.",
)
# Start the web server (runs at import time, as in the original script).
iface.launch()