import gradio as gr
import re
import torch
import sqlite3  # Can be replaced with other DB connections
import os
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel, PeftConfig
from huggingface_hub import login

# ✅ Retrieve the Hugging Face token from an environment variable (Secret), not hardcoded
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN is None:
    raise ValueError("❌ ERROR: Hugging Face token is missing! Please set HF_TOKEN in Hugging Face Secrets.")

# ✅ Authenticate with Hugging Face
login(token=HF_TOKEN)

# ✅ Ensure the offload directory exists
os.makedirs("offload", exist_ok=True)

# ✅ Fine-tuned models on the Hugging Face Model Hub
codellama_model_path = "srishtirai/codellama-sql-finetuned"
mistral_model_path = "srishtirai/mistral-sql-finetuned"


def load_model(model_path):
    """Load a LoRA fine-tuned model and its tokenizer from the Hub."""
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    peft_config = PeftConfig.from_pretrained(model_path)
    base_model_name = peft_config.base_model_name_or_path

    # ✅ Load the base model with offloading & low-memory optimizations
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype=torch.float16,  # FP16 to reduce memory usage
        device_map="auto",          # Automatically distribute across CPU/GPU
        offload_folder="offload",   # Offload weights to disk to prevent memory crashes
        token=HF_TOKEN,             # Authenticate model loading (replaces the deprecated `use_auth_token`)
    )

    # ✅ Load the LoRA adapter in inference-only mode
    model = PeftModel.from_pretrained(
        base_model,
        model_path,
        is_trainable=False,  # Avoids LoRA adapter loading issues at inference time
    )
    model.eval()
    return model, tokenizer


# ✅ Load both models from Hugging Face
codellama_model, codellama_tokenizer = load_model(codellama_model_path)
mistral_model, mistral_tokenizer = load_model(mistral_model_path)


# ✅ Function to format the input prompt
def format_input_prompt(schema, question):
    return f"""### Context:
{schema}

### Question:
{question}

### Response:
Here's the SQL query:
"""


# ✅ Function to generate SQL with an explanation
def generate_sql_with_explanation(model_choice, schema, question, max_new_tokens=512, temperature=0.7):
    """
    Generate a SQL query and explanation based on the selected model.
    """
    # Select the model based on user choice
    if model_choice == "CodeLlama":
        model, tokenizer = codellama_model, codellama_tokenizer
    else:
        model, tokenizer = mistral_model, mistral_tokenizer

    prompt = format_input_prompt(schema, question)

    # Tokenize the input and move it to the model's device
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Generate the response
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=0.95,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode the generated text
    full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Extract the SQL query from a ```sql ... ``` fenced block
    sql_match = re.search(r'```sql\s*(.*?)\s*```', full_response, re.DOTALL)
    sql_query = sql_match.group(1).strip() if sql_match else None

    # Extract the explanation: text after "Explanation:" up to a blank line or end of text
    explanation_match = re.search(r'Explanation:\s*(.*?)($|\n\n)', full_response, re.DOTALL)
    explanation = explanation_match.group(1).strip() if explanation_match else None

    return {
        "query": sql_query or "SQL query extraction failed.",
        "explanation": explanation or "Explanation not found.",
        "full_response": full_response,
    }


# ✅ Function to execute the SQL query (if a database connection is available)
def execute_sql_query(sql_query):
    """
    Runs the generated SQL query on a sample SQLite database.
    (You can replace SQLite with a connection to a real database.)
    """
    try:
        # Temporary in-memory SQLite DB (replace with an actual DB connection).
        # Note: this database starts empty, so queries against pre-existing
        # tables will fail unless the schema is created first.
        conn = sqlite3.connect(":memory:")
        cursor = conn.cursor()
        cursor.execute(sql_query)
        result = cursor.fetchall()
        conn.close()
        return result if result else "Query executed successfully (no output rows)."
    except Exception as e:
        return f"Error executing SQL: {str(e)}"


# ✅ Gradio UI function
def gradio_generate_sql(model_choice, schema, question, run_sql):
    """
    Takes the model selection, schema, and question as input and returns the SQL query
    plus an explanation. Optionally executes the SQL if requested.
    """
    result = generate_sql_with_explanation(model_choice, schema, question)
    sql_query = result["query"]
    if run_sql:
        execution_result = execute_sql_query(sql_query)
        return sql_query, result["explanation"], execution_result
    return sql_query, result["explanation"], "SQL execution not requested."


# ✅ Gradio UI
iface = gr.Interface(
    fn=gradio_generate_sql,
    inputs=[
        gr.Dropdown(["CodeLlama", "Mistral"], label="Choose Model"),
        gr.Textbox(label="Enter Database Schema", lines=10),
        gr.Textbox(label="Enter your Question"),
        gr.Checkbox(label="Run SQL Query?", value=False),
    ],
    outputs=[
        gr.Code(label="Generated SQL Query", language="sql"),  # SQL syntax highlighting
        gr.Textbox(label="Explanation", lines=5),
        gr.Textbox(label="SQL Execution Result", lines=5),
    ],
    title="SQL Query Generator with Execution",
    description="Select a model, then enter your database schema and question. Optionally, execute the generated SQL query.",
)

# ✅ Launch Gradio
iface.launch()
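# ✅ Example usage (hypothetical inputs, for illustration only; the table and
# question below are not part of the app):
#
#   Schema:
#     CREATE TABLE employees (
#         id INTEGER PRIMARY KEY,
#         name TEXT,
#         department TEXT,
#         salary REAL
#     );
#
#   Question: "Which employees earn more than 50000?"
#
# Remember that execute_sql_query() runs against an empty in-memory database,
# so enable "Run SQL Query?" only for statements that create any tables they
# reference, or swap in a real database connection above.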