acadiaway committed
Commit a474ea6 · 1 Parent(s): b4bcb5e

Switch text-to-SQL model to defog/sqlcoder-7b

Files changed (4):
  1. Dockerfile +7 -3
  2. app.py +55 -9
  3. pipeline.py +116 -35
  4. requirements.txt +4 -4
Dockerfile CHANGED
@@ -17,14 +17,18 @@ RUN pip install --no-cache-dir -r requirements.txt
 
 COPY app.py pipeline.py db_utils.py ./
 
-# Set up cache directory
+# Set up cache directory with proper permissions
 RUN mkdir -p /tmp/cache/huggingface && \
     chmod -R 777 /tmp/cache/huggingface
 
+# Environment variables
 ENV HF_HOME=/tmp/cache/huggingface
+ENV TRANSFORMERS_CACHE=/tmp/cache/huggingface
+ENV HF_DATASETS_CACHE=/tmp/cache/huggingface
 ENV PORT=8501
-ENV OMP_NUM_THREADS=8
+ENV OMP_NUM_THREADS=4
+ENV TOKENIZERS_PARALLELISM=false
 
 EXPOSE 8501
 
-CMD ["streamlit", "run", "app.py", "--server.port=8501", "--server.address=0.0.0.0"]
+CMD ["streamlit", "run", "app.py", "--server.port=8501", "--server.address=0.0.0.0"]
app.py CHANGED
@@ -1,13 +1,59 @@
 import streamlit as st
 from pipeline import text_to_sql
 
-st.title("Arctic Text-to-SQL App")
-
-nl_query = st.text_input("Enter your query:", value="List 11 names of ships type schooner")
-if st.button("Generate & Execute"):
-    if nl_query:
-        sql, results = text_to_sql(nl_query)
-        st.write("Generated SQL:", sql)
-        st.write("Results:", results)
+st.title("SQLCoder Text-to-SQL App")
+st.write("Powered by defog/sqlcoder-7b-2 🚀")
+
+# Sample queries for user guidance
+st.sidebar.header("Sample Queries")
+sample_queries = [
+    "List 11 names of ships type schooner",
+    "Show me the 5 oldest ships",
+    "What are the different types of vessels?",
+    "Count the number of ships by type",
+    "Show ships built after 1900"
+]
+
+selected_sample = st.sidebar.selectbox("Choose a sample query:", [""] + sample_queries)
+
+# Main input
+nl_query = st.text_input(
+    "Enter your natural language query:",
+    value=selected_sample if selected_sample else "List 11 names of ships type schooner",
+    help="Ask questions about your database in plain English"
+)
+
+if st.button("🔄 Generate & Execute SQL"):
+    if nl_query.strip():
+        with st.spinner("Generating SQL and executing query..."):
+            try:
+                sql, results = text_to_sql(nl_query)
+
+                # Display results
+                st.success("Query executed successfully!")
+
+                # Show generated SQL
+                st.subheader("Generated SQL:")
+                st.code(sql, language="sql")
+
+                # Show results
+                st.subheader("Results:")
+                if results:
+                    # Convert results to a more readable format
+                    if isinstance(results[0], tuple):
+                        # If results are tuples, display them row by row
+                        st.write(f"Found {len(results)} rows:")
+                        for i, row in enumerate(results[:50]):  # Show first 50 rows
+                            st.write(f"Row {i+1}: {row}")
+                        if len(results) > 50:
+                            st.info(f"Showing first 50 rows out of {len(results)} total results.")
+                    else:
+                        st.write(results)
+                else:
+                    st.info("Query executed successfully but returned no results.")
+
+            except Exception as e:
+                st.error(f"Error: {str(e)}")
+                st.write("Please try rephrasing your query or check if the requested data exists in the database.")
     else:
-        st.error("Enter a query.")
+        st.warning("Please enter a query to proceed.")
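
The row-by-row st.write loop above works, but tuple results can also be rendered as one sortable table. A sketch of that alternative, assuming results is the list of tuples this app's execute_sql returns (it does not return column names, so positional labels are generated here):

import pandas as pd
import streamlit as st

def show_results_table(results, max_rows=50):
    # Render tuple rows as a single dataframe instead of one st.write per row
    if not results:
        st.info("No results.")
        return
    df = pd.DataFrame(results, columns=[f"col_{i}" for i in range(len(results[0]))])
    st.dataframe(df.head(max_rows))
    if len(df) > max_rows:
        st.info(f"Showing first {max_rows} of {len(df)} rows.")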
pipeline.py CHANGED
@@ -1,51 +1,132 @@
 import os
-from transformers import AutoTokenizer
-from vllm import LLM, SamplingParams
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
 from db_utils import get_schema, execute_sql
 
-# Initialize model at startup
+# Initialize model and tokenizer as global variables
 model = None
 tokenizer = None
-try:
-    tokenizer = AutoTokenizer.from_pretrained(
-        "Snowflake/Arctic-Text2SQL-R1-7B",
-        cache_dir="/tmp/cache/huggingface",
-        trust_remote_code=True
-    )
-    model = LLM(
-        model="Snowflake/Arctic-Text2SQL-R1-7B",
-        dtype="float16",
-        gpu_memory_utilization=0.75,
-        max_model_len=1024,
-        max_num_seqs=1,
-        enforce_eager=True,
-        trust_remote_code=True
-    )
-except Exception as e:
-    print(f"Error loading model at startup: {e}")
-    raise
 
-def text_to_sql(nl_query):
+def load_model():
+    """Load SQLCoder model with quantization for memory efficiency"""
+    global model, tokenizer
+
+    if model is not None and tokenizer is not None:
+        return model, tokenizer
+
     try:
-        schema = get_schema()
-        prompt = f"""### Task
-Generate a SQL query to answer the following natural language question: {nl_query}
+        # Configure quantization to reduce memory usage
+        quantization_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_compute_dtype=torch.float16,
+            bnb_4bit_use_double_quant=True,
+            bnb_4bit_quant_type="nf4"
+        )
+
+        # Load tokenizer
+        tokenizer = AutoTokenizer.from_pretrained(
+            "defog/sqlcoder-7b-2",
+            trust_remote_code=True,
+            cache_dir="/tmp/cache/huggingface"
+        )
+
+        # Load model with quantization
+        model = AutoModelForCausalLM.from_pretrained(
+            "defog/sqlcoder-7b-2",
+            quantization_config=quantization_config,
+            device_map="auto",
+            trust_remote_code=True,
+            torch_dtype=torch.float16,
+            cache_dir="/tmp/cache/huggingface"
+        )
+
+        print("SQLCoder model loaded successfully!")
+        return model, tokenizer
+
+    except Exception as e:
+        print(f"Error loading SQLCoder model: {e}")
+        raise e
+
+def generate_sql(nl_query, schema):
+    """Build the SQLCoder prompt for a natural language question"""
+    prompt = f"""### Task
+Generate a PostgreSQL query to answer this question: {nl_query}
 
 ### Database Schema
+The query will run on a database with the following schema:
 {schema}
 
-### Response Format
-Output only the SQL query.
+### Instructions
+- Return only the SQL query, no explanation
+- Use proper PostgreSQL syntax
+- Include appropriate LIMIT clauses if the question asks for a specific number of results
+
+### SQL Query:
 """
-        sampling_params = SamplingParams(
-            temperature=0,
-            max_tokens=128,
-            stop=["\n\n"]
-        )
-        outputs = model.generate([prompt], sampling_params)
-        sql = outputs[0].outputs[0].text.strip()
+    return prompt
+
+def text_to_sql(nl_query):
+    """Main function to convert natural language to SQL and execute it"""
+    try:
+        # Load model if not already loaded
+        model, tokenizer = load_model()
+
+        # Get database schema
+        schema = get_schema()
+
+        # Create the prompt
+        prompt = generate_sql(nl_query, schema)
+
+        # Tokenize input
+        inputs = tokenizer.encode(prompt, return_tensors="pt")
+
+        # Move to appropriate device
+        device = next(model.parameters()).device
+        inputs = inputs.to(device)
+
+        # Generate SQL
+        with torch.no_grad():
+            outputs = model.generate(
+                inputs,
+                max_new_tokens=200,
+                num_beams=4,
+                temperature=0.1,
+                do_sample=False,
+                pad_token_id=tokenizer.eos_token_id,
+                eos_token_id=tokenizer.eos_token_id
+            )
+
+        # Decode the output
+        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+        # Extract just the SQL part (after the prompt)
+        sql_start = generated_text.find("### SQL Query:") + len("### SQL Query:")
+        sql = generated_text[sql_start:].strip()
+
+        # Clean up the SQL (remove any extra text after the query)
+        sql_lines = sql.split('\n')
+        sql = sql_lines[0].strip() if sql_lines else sql.strip()
+
+        # Remove any trailing semicolon and surrounding whitespace
+        sql = sql.rstrip(';').strip()
+
+        # Basic validation
+        if not sql or not sql.lower().startswith('select'):
+            raise ValueError(f"Generated invalid SQL: {sql}")
+
+        print(f"Generated SQL: {sql}")
+
+        # Execute the SQL
         results = execute_sql(sql)
+
         return sql, results
+
     except Exception as e:
         print(f"Error in text_to_sql: {e}")
-        raise
+        return f"Error: {str(e)}", []
+
+# Initialize model on import (optional - can be lazy loaded)
+try:
+    load_model()
+except Exception as e:
+    print(f"Model will be loaded on first use due to: {e}")
requirements.txt CHANGED
@@ -1,8 +1,8 @@
-transformers==4.56.0
-accelerate==1.10.1
+transformers==4.45.2
+accelerate==0.34.2
 psycopg2-binary==2.9.10
 sqlalchemy==2.0.43
 python-dotenv==1.1.1
-vllm==0.10.1
+torch==2.4.1
 streamlit==1.39.0
-torch==2.7.1 --extra-index-url https://download.pytorch.org/whl/cu121
+bitsandbytes==0.43.3
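
Since bitsandbytes 4-bit loading generally requires a CUDA GPU, it is worth failing fast on CPU-only hardware. A small startup-check sketch, assuming the pins above are installed:

from importlib.metadata import version

import torch

# Print the resolved versions of the pinned stack and whether a CUDA
# device is visible; 4-bit quantized loading will not work without one.
for pkg in ("transformers", "accelerate", "torch", "bitsandbytes", "streamlit"):
    print(pkg, version(pkg))
print("CUDA available:", torch.cuda.is_available())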