srishtirai committed
Commit d4125bf · verified · 1 Parent(s): fb8e7e4

Update app.py

Files changed (1)
  1. app.py +132 -9
app.py CHANGED
@@ -1,15 +1,138 @@
-
 import gradio as gr
+import re
+import torch
+import sqlite3
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from peft import PeftModel, PeftConfig
 
-def sql_generator(query):
-    return f"Generated SQL Query for: {query}"
+# ✅ Load fine-tuned models from the Hugging Face Model Hub instead of Kaggle paths
+codellama_model_path = "srishtirai/codellama-sql-finetuned"  # Upload to HF Model Hub
+mistral_model_path = "srishtirai/mistral-sql-finetuned"  # Upload to HF Model Hub
 
+def load_model(model_path):
+    tokenizer = AutoTokenizer.from_pretrained(model_path)
+    tokenizer.pad_token = tokenizer.eos_token
+    tokenizer.padding_side = "right"
+
+    peft_config = PeftConfig.from_pretrained(model_path)
+    base_model_name = peft_config.base_model_name_or_path
+    base_model = AutoModelForCausalLM.from_pretrained(
+        base_model_name,
+        torch_dtype=torch.float16,
+        device_map="auto"
+    )
+    model = PeftModel.from_pretrained(base_model, model_path)
+    model.eval()
+    return model, tokenizer
+
+# ✅ Load both models from Hugging Face
+codellama_model, codellama_tokenizer = load_model(codellama_model_path)
+mistral_model, mistral_tokenizer = load_model(mistral_model_path)
+
+# ✅ Function to format the input prompt
+def format_input_prompt(schema, question):
+    return f"""### Context:
+{schema}
+
+### Question:
+{question}
+
+### Response:
+Here's the SQL query:
+"""
+
+# ✅ Function to generate SQL with an explanation
+def generate_sql_with_explanation(model_choice, schema, question, max_new_tokens=512, temperature=0.7):
+    """
+    Generate a SQL query and explanation with the selected model.
+    """
+    # Select model based on user choice
+    if model_choice == "CodeLlama":
+        model, tokenizer = codellama_model, codellama_tokenizer
+    else:
+        model, tokenizer = mistral_model, mistral_tokenizer
+
+    prompt = format_input_prompt(schema, question)
+
+    # Tokenize input
+    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
+    # Generate response
+    with torch.no_grad():
+        outputs = model.generate(
+            **inputs,
+            max_new_tokens=max_new_tokens,
+            do_sample=True,
+            temperature=temperature,
+            top_p=0.95,
+            pad_token_id=tokenizer.eos_token_id
+        )
+
+    # Decode generated text
+    full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+    # Extract the SQL query
+    sql_match = re.search(r'```sql\s*(.*?)\s*```', full_response, re.DOTALL)
+    sql_query = sql_match.group(1).strip() if sql_match else None
+
+    # Extract the explanation
+    explanation_match = re.search(r'Explanation:\s*(.*?)($|\n\n)', full_response, re.DOTALL)
+    explanation = explanation_match.group(1).strip() if explanation_match else None
+
+    return {
+        "query": sql_query or "SQL query extraction failed.",
+        "explanation": explanation or "Explanation not found.",
+        "full_response": full_response
+    }
+
+# ✅ Function to execute the SQL query (optional)
+def execute_sql_query(sql_query):
+    """
+    Runs the generated SQL query on a temporary in-memory SQLite database.
+    (Replace with a real DB connection if needed.)
+    """
+    try:
+        conn = sqlite3.connect(":memory:")  # Temporary in-memory SQLite DB
+        cursor = conn.cursor()
+        cursor.execute(sql_query)
+        result = cursor.fetchall()
+        conn.close()
+        return result if result else "Query executed successfully (no output rows)."
+    except Exception as e:
+        return f"Error executing SQL: {str(e)}"
+
+# ✅ Gradio UI function
+def gradio_generate_sql(model_choice, schema, question, run_sql):
+    """
+    Takes the model selection, schema, and question as input and returns SQL + explanation.
+    Optionally executes the SQL if requested.
+    """
+    result = generate_sql_with_explanation(model_choice, schema, question)
+    sql_query = result["query"]
+
+    if run_sql:
+        execution_result = execute_sql_query(sql_query)
+        return sql_query, result["explanation"], execution_result
+
+    return sql_query, result["explanation"], "SQL execution not requested."
+
+# ✅ Gradio UI
 iface = gr.Interface(
-    fn=sql_generator,
-    inputs="text",
-    outputs="text",
-    title="SQL Query Generator",
-    description="Enter a question to generate SQL."
+    fn=gradio_generate_sql,
+    inputs=[
+        gr.Dropdown(["CodeLlama", "Mistral"], label="Choose Model"),
+        gr.Textbox(label="Enter Database Schema", lines=10),
+        gr.Textbox(label="Enter your Question"),
+        gr.Checkbox(label="Run SQL Query?", value=False),
+    ],
+    outputs=[
+        gr.Code(label="Generated SQL Query", language="sql"),  # SQL syntax highlighting
+        gr.Textbox(label="Explanation", lines=5),
+        gr.Textbox(label="SQL Execution Result", lines=5),
+    ],
+    title="SQL Query Generator with Execution",
+    description="Select a model, enter your database schema and question. Optionally, execute the generated SQL query.",
 )
 
-iface.launch()
+# ✅ Launch Gradio
+iface.launch()
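
Note on execute_sql_query(): sqlite3.connect(":memory:") opens a fresh, empty database, so a generated query that references the user's tables will raise an error unless those tables are created first. Below is a minimal sketch, not part of this commit, of one way to seed the temporary database from the schema text before executing; execute_with_schema and the sample DDL are illustrative assumptions, not names used by the app.

import sqlite3

def execute_with_schema(schema_ddl, sql_query):
    # Hypothetical helper: create the user's tables in a temporary DB,
    # then run the generated query against them.
    conn = sqlite3.connect(":memory:")
    try:
        conn.executescript(schema_ddl)  # assumes the schema box holds valid SQLite DDL
        cursor = conn.execute(sql_query)
        return cursor.fetchall()
    finally:
        conn.close()

# Illustrative usage:
rows = execute_with_schema(
    "CREATE TABLE employees (id INTEGER, name TEXT);",  # sample DDL, assumption
    "SELECT name FROM employees;",
)
print(rows)  # [] (the table exists but holds no rows, so the query runs without output)

Using executescript() allows multi-statement schemas, and only the internals of execute_sql_query() would need to change; the Gradio interface could stay as committed.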