acadiaway commited on
Commit
5f5380d
·
1 Parent(s): 4e15631

using Gemini NL2SQL

Browse files
Files changed (1) hide show
  1. pipeline.py +65 -50
pipeline.py CHANGED
@@ -2,54 +2,63 @@ import os
2
  import requests
3
  import time
4
  import re
 
5
  from db_utils import get_schema, execute_sql
6
 
7
- # Hugging Face Inference API endpoint for Qwen2.5-Coder
8
- API_URL = "https://api-inference.huggingface.co/models/Qwen/Qwen2.5-Coder-7B-Instruct"
9
-
10
- def query_huggingface_api(prompt, max_retries=3):
11
- """Query the Hugging Face Inference API"""
12
- hf_token = os.getenv("HF_TOKEN")
13
- if not hf_token:
14
- raise ValueError("HF_TOKEN not found in environment variables")
 
 
 
 
15
 
16
- headers = {"Authorization": f"Bearer {hf_token}"}
17
  payload = {
18
- "inputs": prompt,
19
- "parameters": {
20
- "max_new_tokens": 200,
 
 
 
21
  "temperature": 0.1,
22
- "do_sample": False,
23
- "return_full_text": False,
24
- "stop": ["###", "\n\n"]
 
25
  }
26
  }
27
 
28
  for attempt in range(max_retries):
29
  try:
30
- print(f"=== DEBUG: API attempt {attempt + 1}")
31
- response = requests.post(API_URL, headers=headers, json=payload, timeout=30)
32
  print(f"=== DEBUG: API Response Status: {response.status_code}")
33
 
34
  if response.status_code == 200:
35
  result = response.json()
36
  print(f"=== DEBUG: API Response: {result}")
37
 
38
- if isinstance(result, list) and len(result) > 0:
39
- generated_text = result[0].get("generated_text", "").strip()
40
- else:
41
- generated_text = str(result).strip()
42
-
43
- return generated_text
 
44
 
45
- elif response.status_code == 503:
46
- wait_time = 20 * (attempt + 1)
47
- print(f"=== DEBUG: Model loading, waiting {wait_time} seconds...")
48
  time.sleep(wait_time)
49
  continue
50
 
51
  else:
52
- error_msg = f"API Error {response.status_code}: {response.text}"
53
  print(f"=== DEBUG: {error_msg}")
54
  if attempt == max_retries - 1:
55
  raise Exception(error_msg)
@@ -90,9 +99,18 @@ def clean_sql_output(sql_text, user_limit=None):
90
  # Remove markdown formatting
91
  if sql_text.startswith("```"):
92
  lines = sql_text.split('\n')
93
- sql_text = '\n'.join(lines[1:-1]) if len(lines) > 2 else sql_text
 
 
 
 
 
 
 
 
 
94
 
95
- # Handle multiple lines - take the SQL part
96
  lines = sql_text.split('\n')
97
  sql = ""
98
  for line in lines:
@@ -103,10 +121,10 @@ def clean_sql_output(sql_text, user_limit=None):
103
  break
104
 
105
  if not sql:
106
- # If no SELECT found, take the first non-empty line
107
  for line in lines:
108
  line = line.strip()
109
- if line:
110
  sql = line
111
  break
112
 
@@ -120,7 +138,7 @@ def clean_sql_output(sql_text, user_limit=None):
120
  return sql
121
 
122
  def text_to_sql(nl_query):
123
- """Convert natural language to SQL using Qwen2.5-Coder via HF Inference API"""
124
  try:
125
  print(f"=== DEBUG: Starting text_to_sql with query: {nl_query}")
126
 
@@ -136,37 +154,34 @@ def text_to_sql(nl_query):
136
  user_limit = extract_user_requested_limit(nl_query)
137
  print(f"=== DEBUG: Extracted user limit: {user_limit}")
138
 
139
- # Create optimized prompt for Qwen2.5-Coder
140
- prompt = f"""<|im_start|>system
141
- You are an expert SQL developer. Generate PostgreSQL queries based on natural language questions.
 
142
 
143
  Database Schema:
144
  {schema[:1500]}
145
 
146
- Rules:
147
- - Return ONLY the SQL query
148
  - Use PostgreSQL syntax
149
- - Be precise with table and column names
150
- - Do not include explanations or markdown formatting
151
- <|im_end|>
152
- <|im_start|>user
153
- {nl_query}
154
- <|im_end|>
155
- <|im_start|>assistant
156
- """
157
 
158
- print("=== DEBUG: Calling Qwen2.5-Coder API...")
159
- generated_sql = query_huggingface_api(prompt)
160
  print(f"=== DEBUG: Generated SQL raw: {generated_sql}")
161
 
162
- if not generated_sql:
163
- return "Error: No SQL generated from the model", []
164
 
165
  # Clean the SQL output
166
  sql = clean_sql_output(generated_sql, user_limit)
167
  print(f"=== DEBUG: Final cleaned SQL: {sql}")
168
 
169
- if not sql or not sql.upper().startswith('SELECT'):
170
  return f"Error: Invalid SQL generated: {sql}", []
171
 
172
  # Execute SQL
 
2
  import requests
3
  import time
4
  import re
5
+ import json
6
  from db_utils import get_schema, execute_sql
7
 
8
+ def query_gemini_api(prompt, max_retries=3):
9
+ """Query the Google Gemini API"""
10
+ api_key = os.getenv("GOOGLE_API_KEY")
11
+ if not api_key:
12
+ raise ValueError("GOOGLE_API_KEY not found in environment variables")
13
+
14
+ # Gemini API endpoint
15
+ url = f"https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent?key={api_key}"
16
+
17
+ headers = {
18
+ "Content-Type": "application/json"
19
+ }
20
 
 
21
  payload = {
22
+ "contents": [{
23
+ "parts": [{
24
+ "text": prompt
25
+ }]
26
+ }],
27
+ "generationConfig": {
28
  "temperature": 0.1,
29
+ "topK": 1,
30
+ "topP": 0.8,
31
+ "maxOutputTokens": 200,
32
+ "stopSequences": ["```", "\n\n"]
33
  }
34
  }
35
 
36
  for attempt in range(max_retries):
37
  try:
38
+ print(f"=== DEBUG: Gemini API attempt {attempt + 1}")
39
+ response = requests.post(url, headers=headers, json=payload, timeout=30)
40
  print(f"=== DEBUG: API Response Status: {response.status_code}")
41
 
42
  if response.status_code == 200:
43
  result = response.json()
44
  print(f"=== DEBUG: API Response: {result}")
45
 
46
+ if "candidates" in result and len(result["candidates"]) > 0:
47
+ candidate = result["candidates"][0]
48
+ if "content" in candidate and "parts" in candidate["content"]:
49
+ generated_text = candidate["content"]["parts"][0]["text"].strip()
50
+ return generated_text
51
+
52
+ return "No valid response generated"
53
 
54
+ elif response.status_code == 429:
55
+ wait_time = 60 * (attempt + 1) # Rate limit - wait longer
56
+ print(f"=== DEBUG: Rate limited, waiting {wait_time} seconds...")
57
  time.sleep(wait_time)
58
  continue
59
 
60
  else:
61
+ error_msg = f"Gemini API Error {response.status_code}: {response.text}"
62
  print(f"=== DEBUG: {error_msg}")
63
  if attempt == max_retries - 1:
64
  raise Exception(error_msg)
 
99
  # Remove markdown formatting
100
  if sql_text.startswith("```"):
101
  lines = sql_text.split('\n')
102
+ # Find SQL content between backticks
103
+ sql_lines = []
104
+ in_sql = False
105
+ for line in lines:
106
+ if line.strip().startswith("```"):
107
+ in_sql = not in_sql
108
+ continue
109
+ if in_sql:
110
+ sql_lines.append(line)
111
+ sql_text = '\n'.join(sql_lines)
112
 
113
+ # Handle multiple lines - extract the main SELECT query
114
  lines = sql_text.split('\n')
115
  sql = ""
116
  for line in lines:
 
121
  break
122
 
123
  if not sql:
124
+ # If no SELECT found, take the first non-empty line that looks like SQL
125
  for line in lines:
126
  line = line.strip()
127
+ if line and any(keyword in line.upper() for keyword in ['SELECT', 'WITH', 'FROM']):
128
  sql = line
129
  break
130
 
 
138
  return sql
139
 
140
  def text_to_sql(nl_query):
141
+ """Convert natural language to SQL using Google Gemini"""
142
  try:
143
  print(f"=== DEBUG: Starting text_to_sql with query: {nl_query}")
144
 
 
154
  user_limit = extract_user_requested_limit(nl_query)
155
  print(f"=== DEBUG: Extracted user limit: {user_limit}")
156
 
157
+ # Create optimized prompt for Gemini
158
+ prompt = f"""You are an expert PostgreSQL developer. Convert this natural language question to a precise SQL query.
159
+
160
+ Question: {nl_query}
161
 
162
  Database Schema:
163
  {schema[:1500]}
164
 
165
+ Requirements:
166
+ - Generate ONLY the SQL query, no explanation
167
  - Use PostgreSQL syntax
168
+ - Be precise with table and column names from the schema
169
+ - Return a single SELECT statement
170
+
171
+ SQL Query:"""
 
 
 
 
172
 
173
+ print("=== DEBUG: Calling Google Gemini API...")
174
+ generated_sql = query_gemini_api(prompt)
175
  print(f"=== DEBUG: Generated SQL raw: {generated_sql}")
176
 
177
+ if not generated_sql or "No valid response" in generated_sql:
178
+ return "Error: No SQL generated from Gemini", []
179
 
180
  # Clean the SQL output
181
  sql = clean_sql_output(generated_sql, user_limit)
182
  print(f"=== DEBUG: Final cleaned SQL: {sql}")
183
 
184
+ if not sql or not sql.upper().strip().startswith('SELECT'):
185
  return f"Error: Invalid SQL generated: {sql}", []
186
 
187
  # Execute SQL