acadiaway commited on
Commit
d339486
·
1 Parent(s): 50136a9

pipeline with the bug messages for Google API -v3 syntax

Browse files
Files changed (1) hide show
  1. pipeline.py +123 -2
pipeline.py CHANGED
@@ -14,7 +14,7 @@ def query_gemini_api(prompt, max_retries=3):
14
 
15
  # Gemini API endpoint
16
  url = f"https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent?key={api_key}"
17
- print(f=== DEBUG: API URL: {url[:50]}...") # Partial URL for debug
18
 
19
  headers = {
20
  "Content-Type": "application/json"
@@ -67,7 +67,7 @@ def query_gemini_api(prompt, max_retries=3):
67
  if attempt == max_retries - 1:
68
  raise Exception(error_msg)
69
 
70
- except requests.exceptions.Timeout: # Fixed: Added missing colon
71
  print(f"=== DEBUG: Timeout on attempt {attempt + 1}")
72
  if attempt == max_retries - 1:
73
  raise Exception("Request timed out after multiple attempts")
@@ -80,3 +80,124 @@ def query_gemini_api(prompt, max_retries=3):
80
  time.sleep(5)
81
 
82
  raise Exception("Failed to get response after all retries")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  # Gemini API endpoint
16
  url = f"https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent?key={api_key}"
17
+ print(f"=== DEBUG: API URL: {url[:50]}...") # Fixed: Proper f-string syntax
18
 
19
  headers = {
20
  "Content-Type": "application/json"
 
67
  if attempt == max_retries - 1:
68
  raise Exception(error_msg)
69
 
70
+ except requests.exceptions.Timeout: # Previous fix retained
71
  print(f"=== DEBUG: Timeout on attempt {attempt + 1}")
72
  if attempt == max_retries - 1:
73
  raise Exception("Request timed out after multiple attempts")
 
80
  time.sleep(5)
81
 
82
  raise Exception("Failed to get response after all retries")
83
+
84
+ def extract_user_requested_limit(nl_query):
85
+ """Extract user-requested number from natural language query"""
86
+ patterns = [
87
+ r'\b(\d+)\s+(?:ships?|vessels?|boats?|records?|results?|entries?|names?)\b',
88
+ r'(?:show|list|find|get)\s+(?:me\s+)?(?:the\s+)?(?:top\s+|first\s+)?(\d+)',
89
+ r'(?:names\s+of\s+)(\d+)\s+',
90
+ r'\b(\d+)\s+(?:oldest|newest|biggest|smallest|largest)',
91
+ ]
92
+
93
+ for pattern in patterns:
94
+ match = re.search(pattern, nl_query, re.IGNORECASE)
95
+ if match:
96
+ return int(match.group(1))
97
+ return None
98
+
99
+ def clean_sql_output(sql_text, user_limit=None):
100
+ """Clean and validate SQL output from the model"""
101
+ sql_text = sql_text.strip()
102
+
103
+ # Remove markdown formatting
104
+ if sql_text.startswith("```"):
105
+ lines = sql_text.split('\n')
106
+ # Find SQL content between backticks
107
+ sql_lines = []
108
+ in_sql = False
109
+ for line in lines:
110
+ if line.strip().startswith("```"):
111
+ in_sql = not in_sql
112
+ continue
113
+ if in_sql:
114
+ sql_lines.append(line)
115
+ sql_text = '\n'.join(sql_lines)
116
+
117
+ # Handle multiple lines - extract the main SELECT query
118
+ lines = sql_text.split('\n')
119
+ sql = ""
120
+ for line in lines:
121
+ line = line.strip()
122
+ if line and (line.upper().startswith('SELECT') or sql):
123
+ sql += line + " "
124
+ if line.endswith(';'):
125
+ break
126
+
127
+ if not sql:
128
+ # If no SELECT found, take the first non-empty line that looks like SQL
129
+ for line in lines:
130
+ line = line.strip()
131
+ if line and any(keyword in line.upper() for keyword in ['SELECT', 'WITH', 'FROM']):
132
+ sql = line
133
+ break
134
+
135
+ sql = sql.strip().rstrip(';')
136
+
137
+ # Apply user-requested limit
138
+ if user_limit:
139
+ sql = re.sub(r'\s+LIMIT\s+\d+', '', sql, flags=re.IGNORECASE)
140
+ sql += f" LIMIT {user_limit}"
141
+
142
+ return sql
143
+
144
+ def text_to_sql(nl_query):
145
+ """Convert natural language to SQL using Google Gemini"""
146
+ try:
147
+ print(f"=== DEBUG: Starting text_to_sql with query: {nl_query}")
148
+
149
+ # Get database schema
150
+ try:
151
+ schema = get_schema()
152
+ print(f"=== DEBUG: Schema retrieved, length: {len(schema)}")
153
+ except Exception as e:
154
+ print(f"=== DEBUG: Schema error: {e}")
155
+ return f"Error: Database schema access failed: {str(e)}", []
156
+
157
+ # Extract user limit
158
+ user_limit = extract_user_requested_limit(nl_query)
159
+ print(f"=== DEBUG: Extracted user limit: {user_limit}")
160
+
161
+ # Create optimized prompt for Gemini
162
+ prompt = f"""You are an expert PostgreSQL developer. Convert this natural language question to a precise SQL query.
163
+
164
+ Question: {nl_query}
165
+
166
+ Database Schema:
167
+ {schema[:1500]}
168
+
169
+ Requirements:
170
+ - Generate ONLY the SQL query, no explanation
171
+ - Use PostgreSQL syntax
172
+ - Be precise with table and column names from the schema
173
+ - Return a single SELECT statement
174
+
175
+ SQL Query:"""
176
+
177
+ print(f"=== DEBUG: Calling Google Gemini API...")
178
+ generated_sql = query_gemini_api(prompt)
179
+ print(f"=== DEBUG: Generated SQL raw: {generated_sql}")
180
+
181
+ if not generated_sql or "No valid response" in generated_sql:
182
+ return "Error: No SQL generated from Gemini", []
183
+
184
+ # Clean the SQL output
185
+ sql = clean_sql_output(generated_sql, user_limit)
186
+ print(f"=== DEBUG: Final cleaned SQL: {sql}")
187
+
188
+ if not sql or not sql.upper().strip().startswith('SELECT'):
189
+ return f"Error: Invalid SQL generated: {sql}", []
190
+
191
+ # Execute SQL
192
+ print(f"=== DEBUG: Executing SQL...")
193
+ try:
194
+ results = execute_sql(sql)
195
+ print(f"=== DEBUG: SQL executed successfully, {len(results)} results")
196
+ return sql, results
197
+ except Exception as e:
198
+ print(f"=== DEBUG: SQL execution error: {e}")
199
+ return f"Error: SQL execution failed: {str(e)}", []
200
+
201
+ except Exception as e:
202
+ print(f"=== DEBUG: General error in text_to_sql: {e}")
203
+ return f"Error: {str(e)}", []