acadiaway committed on
Commit
b6144db
·
1 Parent(s): 3ddd1c3
Files changed (1) hide show
  1. pipeline.py +6 -88
pipeline.py CHANGED
@@ -8,13 +8,11 @@ from db_utils import get_schema, execute_sql
8
  def query_gemini_api(prompt, max_retries=3):
9
  """Query the Google Gemini API"""
10
  api_key = os.getenv("GOOGLE_API_KEY")
11
- print(f"=== DEBUG: API Key Loaded: {api_key[:5]}...") # Partial key for debug
12
  if not api_key:
13
  raise ValueError("GOOGLE_API_KEY not found in environment variables")
14
 
15
  # Gemini API endpoint
16
  url = f"https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent?key={api_key}"
17
- print(f"=== DEBUG: API URL: {url[:50]}...")
18
 
19
  headers = {
20
  "Content-Type": "application/json"
@@ -52,21 +50,19 @@ def query_gemini_api(prompt, max_retries=3):
52
  }
53
  ]
54
  }
55
- print(f"=== DEBUG: Payload: {json.dumps(payload, indent=2)}")
56
 
57
  for attempt in range(max_retries):
58
  try:
59
- print(f"=== DEBUG: Attempt {attempt + 1} of {max_retries}")
60
  response = requests.post(url, headers=headers, json=payload, timeout=30)
61
- print(f"=== DEBUG: API Response Status: {response.status_code}")
62
- print(f"=== DEBUG: Response Text: {response.text[:200]}...") # Partial response
63
 
64
  if response.status_code == 200:
65
  result = response.json()
66
- print(f"=== DEBUG: API Response: {result}")
67
-
68
  if "candidates" in result and len(result["candidates"]) > 0:
69
  candidate = result["candidates"][0]
 
 
 
 
70
  if "content" in candidate and "parts" in candidate["content"]:
71
  generated_text = candidate["content"]["parts"][0]["text"].strip()
72
  return generated_text
@@ -74,25 +70,20 @@ def query_gemini_api(prompt, max_retries=3):
74
  return "No valid response generated"
75
 
76
  elif response.status_code == 429:
77
- wait_time = 60 * (attempt + 1) # Rate limit - wait longer
78
- print(f"=== DEBUG: Rate limited, waiting {wait_time} seconds...")
79
  time.sleep(wait_time)
80
  continue
81
-
82
  else:
83
  error_msg = f"Gemini API Error {response.status_code}: {response.text}"
84
- print(f"=== DEBUG: {error_msg}")
85
  if attempt == max_retries - 1:
86
  raise Exception(error_msg)
87
 
88
  except requests.exceptions.Timeout:
89
- print(f"=== DEBUG: Timeout on attempt {attempt + 1}")
90
  if attempt == max_retries - 1:
91
  raise Exception("Request timed out after multiple attempts")
92
  time.sleep(5)
93
 
94
  except Exception as e:
95
- print(f"=== DEBUG: Exception on attempt {attempt + 1}: {e}")
96
  if attempt == max_retries - 1:
97
  raise e
98
  time.sleep(5)
@@ -118,10 +109,8 @@ def clean_sql_output(sql_text, user_limit=None):
118
  """Clean and validate SQL output from the model"""
119
  sql_text = sql_text.strip()
120
 
121
- # Remove markdown formatting
122
  if sql_text.startswith("```"):
123
  lines = sql_text.split('\n')
124
- # Find SQL content between backticks
125
  sql_lines = []
126
  in_sql = False
127
  for line in lines:
@@ -132,79 +121,8 @@ def clean_sql_output(sql_text, user_limit=None):
132
  sql_lines.append(line)
133
  sql_text = '\n'.join(sql_lines)
134
 
135
- # Handle multiple lines - extract the main SELECT query
136
  lines = sql_text.split('\n')
137
  sql = ""
138
  for line in lines:
139
  line = line.strip()
140
- if line and (line.upper().startswith('SELECT') or sql):
141
- sql += line + " "
142
- if line.endswith(';'):
143
- break
144
-
145
- if not sql:
146
- # If no SELECT found, take the first non-empty line that looks like SQL
147
- for line in lines:
148
- line = line.strip()
149
- if line and any(keyword in line.upper() for keyword in ['SELECT', 'WITH', 'FROM']):
150
- sql = line
151
- break
152
-
153
- sql = sql.strip().rstrip(';')
154
-
155
- # Apply user-requested limit
156
- if user_limit:
157
- sql = re.sub(r'\s+LIMIT\s+\d+', '', sql, flags=re.IGNORECASE)
158
- sql += f" LIMIT {user_limit}"
159
-
160
- return sql
161
-
162
- def text_to_sql(nl_query):
163
- """Convert natural language to SQL using Google Gemini"""
164
- try:
165
- print(f"=== DEBUG: Starting text_to_sql with query: {nl_query}")
166
-
167
- # Get database schema
168
- try:
169
- schema = get_schema()
170
- print(f"=== DEBUG: Schema retrieved, length: {len(schema)}")
171
- except Exception as e:
172
- print(f"=== DEBUG: Schema error: {e}")
173
- return f"Error: Database schema access failed: {str(e)}", []
174
-
175
- # Extract user limit
176
- user_limit = extract_user_requested_limit(nl_query)
177
- print(f"=== DEBUG: Extracted user limit: {user_limit}")
178
-
179
- # Create optimized prompt for Gemini
180
- prompt = f"""You are an expert PostgreSQL developer. Convert this natural language question to a precise SQL query.
181
-
182
- Question: {nl_query}
183
-
184
- Database Schema:
185
- {schema[:1500]}
186
-
187
- Requirements:
188
- - Generate ONLY the SQL query, no explanation
189
- - Use PostgreSQL syntax
190
- - Be precise with table and column names from the schema
191
- - Return a single SELECT statement
192
-
193
- SQL Query:"""
194
-
195
- print(f"=== DEBUG: Calling Google Gemini API...")
196
- generated_sql = query_gemini_api(prompt)
197
- print(f"=== DEBUG: Generated SQL raw: {generated_sql}")
198
-
199
- if not generated_sql or "No valid response" in generated_sql:
200
- return "Error: No SQL generated from Gemini", []
201
-
202
- # Clean the SQL output
203
- sql = clean_sql_output(generated_sql, user_limit)
204
- print(f"=== DEBUG: Final cleaned SQL: {sql}")
205
-
206
- if not sql or not sql.upper().strip().startswith('SELECT'):
207
- return f"Error: Invalid SQL generated: {sql}", []
208
-
209
- # Execute SQL
210
- print(f"=== DEBUG
 
8
  def query_gemini_api(prompt, max_retries=3):
9
  """Query the Google Gemini API"""
10
  api_key = os.getenv("GOOGLE_API_KEY")
 
11
  if not api_key:
12
  raise ValueError("GOOGLE_API_KEY not found in environment variables")
13
 
14
  # Gemini API endpoint
15
  url = f"https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent?key={api_key}"
 
16
 
17
  headers = {
18
  "Content-Type": "application/json"
 
50
  }
51
  ]
52
  }
 
53
 
54
  for attempt in range(max_retries):
55
  try:
 
56
  response = requests.post(url, headers=headers, json=payload, timeout=30)
 
 
57
 
58
  if response.status_code == 200:
59
  result = response.json()
 
 
60
  if "candidates" in result and len(result["candidates"]) > 0:
61
  candidate = result["candidates"][0]
62
+ # Check for safety ratings first to see if it was blocked
63
+ if candidate.get('finishReason') == 'SAFETY':
64
+ print("=== DEBUG: Gemini response blocked due to safety settings.")
65
+ return "Error: The response was blocked by safety filters."
66
  if "content" in candidate and "parts" in candidate["content"]:
67
  generated_text = candidate["content"]["parts"][0]["text"].strip()
68
  return generated_text
 
70
  return "No valid response generated"
71
 
72
  elif response.status_code == 429:
73
+ wait_time = 60 * (attempt + 1)
 
74
  time.sleep(wait_time)
75
  continue
 
76
  else:
77
  error_msg = f"Gemini API Error {response.status_code}: {response.text}"
 
78
  if attempt == max_retries - 1:
79
  raise Exception(error_msg)
80
 
81
  except requests.exceptions.Timeout:
 
82
  if attempt == max_retries - 1:
83
  raise Exception("Request timed out after multiple attempts")
84
  time.sleep(5)
85
 
86
  except Exception as e:
 
87
  if attempt == max_retries - 1:
88
  raise e
89
  time.sleep(5)
 
109
  """Clean and validate SQL output from the model"""
110
  sql_text = sql_text.strip()
111
 
 
112
  if sql_text.startswith("```"):
113
  lines = sql_text.split('\n')
 
114
  sql_lines = []
115
  in_sql = False
116
  for line in lines:
 
121
  sql_lines.append(line)
122
  sql_text = '\n'.join(sql_lines)
123
 
 
124
  lines = sql_text.split('\n')
125
  sql = ""
126
  for line in lines:
127
  line = line.strip()
128
+ if line and (line.upper().startswith('SELECT') or sql