import gradio as gr
import re
import torch
import sqlite3  # Can be replaced with other DB connections
import os

from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel, PeftConfig
from huggingface_hub import login
# Retrieve the Hugging Face token from an environment variable (Space secret),
# never hardcoded in the source
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN is None:
    raise ValueError("ERROR: Hugging Face token is missing! Please set HF_TOKEN in Hugging Face Secrets.")

# Authenticate with Hugging Face
login(token=HF_TOKEN)
# Ensure the offload directory exists (used for disk offloading below)
os.makedirs("offload", exist_ok=True)

# Fine-tuned models on the Hugging Face Model Hub
codellama_model_path = "srishtirai/codellama-sql-finetuned"
mistral_model_path = "srishtirai/mistral-sql-finetuned"
def load_model(model_path):
    """Load a LoRA adapter plus its base model and tokenizer for inference."""
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    peft_config = PeftConfig.from_pretrained(model_path)
    base_model_name = peft_config.base_model_name_or_path

    # Load the base model with offloading & low-memory optimization
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype=torch.float16,  # Use FP16 to reduce memory usage
        device_map="auto",          # Automatically distribute across CPU/GPU
        offload_folder="offload",   # Offload weights to disk to prevent memory crashes
        token=HF_TOKEN,             # Authenticate model loading (`use_auth_token` is deprecated)
    )

    # Load the LoRA adapter in inference-only mode
    model = PeftModel.from_pretrained(
        base_model,
        model_path,
        is_trainable=False,  # Avoids adapter-loading issues during inference
    )
    model.eval()
    return model, tokenizer
# Load both models at startup (the first run downloads the weights, which can
# take a while and needs substantial memory)
codellama_model, codellama_tokenizer = load_model(codellama_model_path)
mistral_model, mistral_tokenizer = load_model(mistral_model_path)
# Format the input prompt for the model
def format_input_prompt(schema, question):
    return f"""### Context:
{schema}
### Question:
{question}
### Response:
Here's the SQL query:
"""
# Generate SQL with an explanation using the selected model
def generate_sql_with_explanation(model_choice, schema, question, max_new_tokens=512, temperature=0.7):
    """
    Generate SQL query and explanation based on the selected model.
    """
    # Select model based on user choice
    if model_choice == "CodeLlama":
        model, tokenizer = codellama_model, codellama_tokenizer
    else:
        model, tokenizer = mistral_model, mistral_tokenizer

    prompt = format_input_prompt(schema, question)

    # Tokenize input
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Generate response
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=0.95,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode generated text
    full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Extract the SQL query from a ```sql ... ``` fenced block, if present
    sql_match = re.search(r'```sql\s*(.*?)\s*```', full_response, re.DOTALL)
    sql_query = sql_match.group(1).strip() if sql_match else None

    # Extract the explanation that follows an "Explanation:" marker
    explanation_match = re.search(r'Explanation:\s*(.*?)($|\n\n)', full_response, re.DOTALL)
    explanation = explanation_match.group(1).strip() if explanation_match else None

    return {
        "query": sql_query or "SQL query extraction failed.",
        "explanation": explanation or "Explanation not found.",
        "full_response": full_response,
    }
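# Illustrative usage (hypothetical schema/question; kept as a comment so the
# Space does not run generation at import time):
#   result = generate_sql_with_explanation(
#       "CodeLlama",
#       "CREATE TABLE employees (id INT, name TEXT, salary REAL)",
#       "Who are the top 5 earners?",
#   )
#   print(result["query"])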
# Execute the SQL query (if a database connection is available)
def execute_sql_query(sql_query):
    """
    Runs the generated SQL query on a temporary in-memory SQLite database.
    (Replace SQLite with a connection to a real database as needed.)
    Note: the in-memory database starts empty, so queries that reference tables
    from the schema will fail unless those tables are created first.
    """
    try:
        conn = sqlite3.connect(":memory:")  # Temporary SQLite DB
        cursor = conn.cursor()
        cursor.execute(sql_query)
        result = cursor.fetchall()
        conn.close()
        return result if result else "Query executed successfully (No output rows)."
    except Exception as e:
        return f"Error executing SQL: {str(e)}"
# Gradio callback
def gradio_generate_sql(model_choice, schema, question, run_sql):
    """
    Takes model selection, schema & question as input and returns SQL + explanation.
    Optionally executes the SQL if requested.
    """
    result = generate_sql_with_explanation(model_choice, schema, question)
    sql_query = result["query"]
    if run_sql:
        execution_result = execute_sql_query(sql_query)
        return sql_query, result["explanation"], execution_result
    return sql_query, result["explanation"], "SQL execution not requested."
# Gradio UI
iface = gr.Interface(
    fn=gradio_generate_sql,
    inputs=[
        gr.Dropdown(["CodeLlama", "Mistral"], label="Choose Model"),
        gr.Textbox(label="Enter Database Schema", lines=10),
        gr.Textbox(label="Enter your Question"),
        gr.Checkbox(label="Run SQL Query?", value=False),
    ],
    outputs=[
        gr.Code(label="Generated SQL Query", language="sql"),  # SQL syntax highlighting
        gr.Textbox(label="Explanation", lines=5),
        gr.Textbox(label="SQL Execution Result", lines=5),
    ],
    title="SQL Query Generator with Execution",
    description="Select a model, enter your database schema and question. Optionally, execute the generated SQL query.",
)
# Launch the Gradio app
iface.launch()
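# On Spaces the default launch() is sufficient; when running locally you can
# pass options such as iface.launch(share=True) for a temporary public URL.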