srishtirai committed on
Commit
55215ba
·
verified ·
1 Parent(s): 81657a9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -15
app.py CHANGED
@@ -1,17 +1,17 @@
1
  import gradio as gr
2
  import re
3
  import torch
4
- import sqlite3
 
5
  from transformers import AutoTokenizer, AutoModelForCausalLM
6
  from peft import PeftModel, PeftConfig
7
- import torch
8
- import os
9
- os.makedirs("offload", exist_ok=True)
10
 
11
- # ✅ Load fine-tuned models from Hugging Face Model Hub instead of Kaggle paths
12
- codellama_model_path = "srishtirai/codellama-sql-finetuned" # Upload to HF Model Hub
13
- mistral_model_path = "srishtirai/mistral-sql-finetuned" # Upload to HF Model Hub
14
 
 
 
 
15
 
16
  def load_model(model_path):
17
  tokenizer = AutoTokenizer.from_pretrained(model_path)
@@ -21,18 +21,24 @@ def load_model(model_path):
21
  peft_config = PeftConfig.from_pretrained(model_path)
22
  base_model_name = peft_config.base_model_name_or_path
23
 
 
24
  base_model = AutoModelForCausalLM.from_pretrained(
25
  base_model_name,
26
- torch_dtype=torch.float16, # Use FP16 to save memory
27
- device_map="auto", # Automatically allocate layers to CPU/GPU
28
- offload_folder="offload" # ✅ Offload large layers to disk
 
 
 
 
 
 
 
29
  )
30
 
31
- model = PeftModel.from_pretrained(base_model, model_path)
32
  model.eval()
33
  return model, tokenizer
34
 
35
-
36
  # ✅ Load both models from Hugging Face
37
  codellama_model, codellama_tokenizer = load_model(codellama_model_path)
38
  mistral_model, mistral_tokenizer = load_model(mistral_model_path)
@@ -93,14 +99,14 @@ def generate_sql_with_explanation(model_choice, schema, question, max_new_tokens
93
  "full_response": full_response
94
  }
95
 
96
- # ✅ Function to execute SQL query (Optional)
97
  def execute_sql_query(sql_query):
98
  """
99
  Runs the generated SQL query on a sample SQLite database.
100
- (Replace with a real DB connection if needed)
101
  """
102
  try:
103
- conn = sqlite3.connect(":memory:") # Temporary SQLite DB
104
  cursor = conn.cursor()
105
  cursor.execute(sql_query)
106
  result = cursor.fetchall()
 
1
  import gradio as gr
2
  import re
3
  import torch
4
+ import sqlite3 # Can be replaced with other DB connections
5
+ import os
6
  from transformers import AutoTokenizer, AutoModelForCausalLM
7
  from peft import PeftModel, PeftConfig
 
 
 
8
 
9
+ # ✅ Ensure offload directory exists
10
+ os.makedirs("offload", exist_ok=True)
 
11
 
12
+ # ✅ Load fine-tuned models from Hugging Face Model Hub
13
+ codellama_model_path = "srishtirai/codellama-sql-finetuned"
14
+ mistral_model_path = "srishtirai/mistral-sql-finetuned"
15
 
16
  def load_model(model_path):
17
  tokenizer = AutoTokenizer.from_pretrained(model_path)
 
21
  peft_config = PeftConfig.from_pretrained(model_path)
22
  base_model_name = peft_config.base_model_name_or_path
23
 
24
+ # ✅ Load base model with offloading & low-memory optimization
25
  base_model = AutoModelForCausalLM.from_pretrained(
26
  base_model_name,
27
+ torch_dtype=torch.float16, # Use FP16 to reduce memory usage
28
+ device_map="auto", # Automatically distribute across CPU/GPU
29
+ offload_folder="offload" # ✅ Prevents memory crashes
30
+ )
31
+
32
+ # ✅ Load LoRA adapter with `is_trainable=False`
33
+ model = PeftModel.from_pretrained(
34
+ base_model,
35
+ model_path,
36
+ is_trainable=False # ✅ Fixes LoRA adapter loading issues
37
  )
38
 
 
39
  model.eval()
40
  return model, tokenizer
41
 
 
42
  # ✅ Load both models from Hugging Face
43
  codellama_model, codellama_tokenizer = load_model(codellama_model_path)
44
  mistral_model, mistral_tokenizer = load_model(mistral_model_path)
 
99
  "full_response": full_response
100
  }
101
 
102
+ # ✅ Function to execute SQL query (if database connection is available)
103
  def execute_sql_query(sql_query):
104
  """
105
  Runs the generated SQL query on a sample SQLite database.
106
+ (You can replace SQLite with a connection to a real database)
107
  """
108
  try:
109
+ conn = sqlite3.connect(":memory:") # Temporary SQLite DB (Replace with actual DB connection)
110
  cursor = conn.cursor()
111
  cursor.execute(sql_query)
112
  result = cursor.fetchall()