acadiaway committed
Commit eda83bc · Parent: 5356085

Switch to HF Inference API approach - eliminate model loading

Files changed (3)
  1. Dockerfile +1 -15
  2. pipeline.py +115 -107
  3. requirements.txt +2 -5
Dockerfile CHANGED
@@ -6,28 +6,14 @@ RUN apt-get update && apt-get install -y \
     build-essential \
     libpq-dev \
     curl \
-    git \
     && rm -rf /var/lib/apt/lists/*
 
 COPY requirements.txt .
-
-# Upgrade pip and install dependencies
-RUN pip install --upgrade pip
-RUN pip install --no-cache-dir -r requirements.txt
+RUN pip install --upgrade pip && pip install --no-cache-dir -r requirements.txt
 
 COPY app.py pipeline.py db_utils.py ./
 
-# Set up cache directory with proper permissions
-RUN mkdir -p /tmp/cache/huggingface && \
-    chmod -R 777 /tmp/cache/huggingface
-
-# Environment variables
-ENV HF_HOME=/tmp/cache/huggingface
-ENV TRANSFORMERS_CACHE=/tmp/cache/huggingface
-ENV HF_DATASETS_CACHE=/tmp/cache/huggingface
 ENV PORT=8501
-ENV OMP_NUM_THREADS=4
-ENV TOKENIZERS_PARALLELISM=false
 
 EXPOSE 8501
pipeline.py CHANGED
@@ -1,132 +1,140 @@
-import os
-import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
-from db_utils import get_schema, execute_sql
-
-# Initialize model and tokenizer as global variables
-model = None
-tokenizer = None
-
-def load_model():
-    """Load SQLCoder model with quantization for memory efficiency"""
-    global model, tokenizer
-
-    if model is not None and tokenizer is not None:
-        return model, tokenizer
-
-    try:
-        # Configure quantization to reduce memory usage
-        quantization_config = BitsAndBytesConfig(
-            load_in_4bit=True,
-            bnb_4bit_compute_dtype=torch.float16,
-            bnb_4bit_use_double_quant=True,
-            bnb_4bit_quant_type="nf4"
-        )
-
-        # Load tokenizer
-        tokenizer = AutoTokenizer.from_pretrained(
-            "defog/sqlcoder-7b-2",
-            trust_remote_code=True,
-            cache_dir="/tmp/cache/huggingface"
-        )
-
-        # Load model with quantization
-        model = AutoModelForCausalLM.from_pretrained(
-            "defog/sqlcoder-7b-2",
-            quantization_config=quantization_config,
-            device_map="auto",
-            trust_remote_code=True,
-            torch_dtype=torch.float16,
-            cache_dir="/tmp/cache/huggingface"
-        )
-
-        print("SQLCoder model loaded successfully!")
-        return model, tokenizer
-
-    except Exception as e:
-        print(f"Error loading SQLCoder model: {e}")
-        raise e
-
-def generate_sql(nl_query, schema):
-    """Generate SQL using SQLCoder with proper prompting"""
-    prompt = f"""### Task
-Generate a PostgreSQL query to answer this question: {nl_query}
-
-### Database Schema
-The query will run on a database with the following schema:
-{schema}
-
-### Instructions
-- Return only the SQL query, no explanation
-- Use proper PostgreSQL syntax
-- Include appropriate LIMIT clauses if the question asks for a specific number of results
-
-### SQL Query:
-"""
-    return prompt
-
-def text_to_sql(nl_query):
-    """Main function to convert natural language to SQL and execute it"""
-    try:
-        # Load model if not already loaded
-        model, tokenizer = load_model()
-
-        # Get database schema
-        schema = get_schema()
-
-        # Create the prompt
-        prompt = generate_sql(nl_query, schema)
-
-        # Tokenize input
-        inputs = tokenizer.encode(prompt, return_tensors="pt")
-
-        # Move to appropriate device
-        device = next(model.parameters()).device
-        inputs = inputs.to(device)
-
-        # Generate SQL
-        with torch.no_grad():
-            outputs = model.generate(
-                inputs,
-                max_new_tokens=200,
-                num_beams=4,
-                temperature=0.1,
-                do_sample=False,
-                pad_token_id=tokenizer.eos_token_id,
-                eos_token_id=tokenizer.eos_token_id
-            )
-
-        # Decode the output
-        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-        # Extract just the SQL part (after the prompt)
-        sql_start = generated_text.find("### SQL Query:") + len("### SQL Query:")
-        sql = generated_text[sql_start:].strip()
-
-        # Clean up the SQL (remove any extra text after the query)
-        sql_lines = sql.split('\n')
-        sql = sql_lines[0].strip() if sql_lines else sql.strip()
-
-        # Remove any trailing semicolon if present and clean
-        sql = sql.rstrip(';').strip()
-
-        # Basic validation
-        if not sql or not sql.lower().startswith('select'):
-            raise ValueError(f"Generated invalid SQL: {sql}")
-
-        print(f"Generated SQL: {sql}")
-
-        # Execute the SQL
-        results = execute_sql(sql)
-
-        return sql, results
-
-    except Exception as e:
-        print(f"Error in text_to_sql: {e}")
-        return f"Error: {str(e)}", []
-
-# Initialize model on import (optional - can be lazy loaded)
-try:
-    load_model()
-except Exception as e:
-    print(f"Model will be loaded on first use due to: {e}")
+import os
+import requests
+import time
+import re
+from db_utils import get_schema, execute_sql
+
+# Hugging Face Inference API endpoint
+API_URL = "https://api-inference.huggingface.co/models/defog/sqlcoder-7b-2"
+
+def query_huggingface_api(prompt, max_retries=3):
+    """Query the Hugging Face Inference API"""
+    hf_token = os.getenv("HF_TOKEN")
+    if not hf_token:
+        raise ValueError("HF_TOKEN not found in environment variables. Add it to your Space secrets.")
+
+    headers = {"Authorization": f"Bearer {hf_token}"}
+    payload = {
+        "inputs": prompt,
+        "parameters": {
+            "max_new_tokens": 200,
+            "temperature": 0.1,
+            "do_sample": False,
+            "return_full_text": False
+        }
+    }
+
+    for attempt in range(max_retries):
+        try:
+            response = requests.post(API_URL, headers=headers, json=payload, timeout=30)
+
+            if response.status_code == 200:
+                result = response.json()
+                if isinstance(result, list) and len(result) > 0:
+                    return result[0].get("generated_text", "").strip()
+                return str(result).strip()
+
+            elif response.status_code == 503:
+                wait_time = 20 * (attempt + 1)
+                print(f"Model loading, waiting {wait_time} seconds...")
+                time.sleep(wait_time)
+                continue
+
+            else:
+                error_msg = f"API Error {response.status_code}: {response.text}"
+                if attempt == max_retries - 1:
+                    raise Exception(error_msg)
+
+        except requests.exceptions.Timeout:
+            if attempt == max_retries - 1:
+                raise Exception("Request timed out after multiple attempts")
+            time.sleep(5)
+
+        except Exception as e:
+            if attempt == max_retries - 1:
+                raise e
+            time.sleep(5)
+
+    raise Exception("Failed to get response after all retries")
+
+def extract_user_requested_limit(nl_query):
+    """Extract user-requested number from natural language query"""
+    patterns = [
+        r'\b(\d+)\s+(?:ships?|vessels?|boats?|records?|results?|entries?|names?)\b',
+        r'(?:show|list|find|get)\s+(?:me\s+)?(?:the\s+)?(?:top\s+|first\s+)?(\d+)',
+        r'(?:names\s+of\s+)(\d+)\s+',
+        r'\b(\d+)\s+(?:oldest|newest|biggest|smallest|largest)',
+    ]
+
+    for pattern in patterns:
+        match = re.search(pattern, nl_query, re.IGNORECASE)
+        if match:
+            return int(match.group(1))
+    return None
+
+def clean_sql_output(sql_text, user_limit=None):
+    """Clean and validate SQL output from the model"""
+    sql_text = sql_text.strip()
+
+    # Remove markdown formatting
+    if sql_text.startswith("```"):
+        lines = sql_text.split('\n')
+        sql_text = '\n'.join(lines[1:-1]) if len(lines) > 2 else sql_text
+
+    # Extract SQL
+    lines = sql_text.split('\n')
+    sql = ""
+    for line in lines:
+        line = line.strip()
+        if line and (line.upper().startswith('SELECT') or sql):
+            sql += line + " "
+            if line.endswith(';'):
+                break
+
+    sql = sql.strip().rstrip(';')
+
+    # Apply user-requested limit
+    if user_limit:
+        sql = re.sub(r'\s+LIMIT\s+\d+', '', sql, flags=re.IGNORECASE)
+        sql += f" LIMIT {user_limit}"
+
+    return sql
+
+def text_to_sql(nl_query):
+    """Convert natural language to SQL using Hugging Face Inference API"""
+    try:
+        schema = get_schema()
+        user_limit = extract_user_requested_limit(nl_query)
+
+        prompt = f"""### Task
+Generate a PostgreSQL query to answer this question: {nl_query}
+
+### Database Schema
+{schema}
+
+### Instructions
+- Return only the SQL query
+- Use PostgreSQL syntax
+- Be precise with table and column names
+
+### SQL Query:"""
+
+        print("Querying Hugging Face Inference API...")
+        generated_sql = query_huggingface_api(prompt)
+
+        if not generated_sql:
+            raise ValueError("No SQL generated from the model")
+
+        sql = clean_sql_output(generated_sql, user_limit)
+
+        if not sql or not sql.upper().startswith('SELECT'):
+            raise ValueError(f"Invalid SQL generated: {sql}")
+
+        print(f"Generated SQL: {sql}")
+        results = execute_sql(sql)
+        return sql, results
+
+    except Exception as e:
+        error_msg = str(e)
+        print(f"Error in text_to_sql: {error_msg}")
+        return f"Error: {error_msg}", []
requirements.txt CHANGED
@@ -1,8 +1,5 @@
-transformers==4.45.2
-accelerate==0.34.2
+requests==2.31.0
 psycopg2-binary==2.9.10
 sqlalchemy==2.0.43
 python-dotenv==1.1.1
-torch==2.4.1
 streamlit==1.39.0
-bitsandbytes==0.43.3
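With the trimmed dependencies, the whole path can be smoke-tested from a Python shell inside the Space. A minimal sketch, assuming HF_TOKEN is set as a Space secret, the database behind db_utils is reachable, and using a hypothetical question:

    # Sketch: end-to-end call through the new Inference API path.
    import os
    assert os.getenv("HF_TOKEN"), "set HF_TOKEN (Space secret) first"

    from pipeline import text_to_sql

    sql, rows = text_to_sql("show me the 5 oldest ships")  # hypothetical question
    print(sql)        # the generated (and cleaned) SELECT statement
    for row in rows:  # rows come back from db_utils.execute_sql
        print(row)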