acadiaway committed on
Commit
e186332
·
1 Parent(s): 5f5380d

pipeline with the debug messages for Google API

Browse files
Files changed (1) hide show
  1. pipeline.py +6 -135
pipeline.py CHANGED
@@ -8,11 +8,13 @@ from db_utils import get_schema, execute_sql
8
  def query_gemini_api(prompt, max_retries=3):
9
  """Query the Google Gemini API"""
10
  api_key = os.getenv("GOOGLE_API_KEY")
 
11
  if not api_key:
12
  raise ValueError("GOOGLE_API_KEY not found in environment variables")
13
 
14
  # Gemini API endpoint
15
  url = f"https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent?key={api_key}"
 
16
 
17
  headers = {
18
  "Content-Type": "application/json"
@@ -32,12 +34,14 @@ def query_gemini_api(prompt, max_retries=3):
32
  "stopSequences": ["```", "\n\n"]
33
  }
34
  }
 
35
 
36
  for attempt in range(max_retries):
37
  try:
38
- print(f"=== DEBUG: Gemini API attempt {attempt + 1}")
39
  response = requests.post(url, headers=headers, json=payload, timeout=30)
40
  print(f"=== DEBUG: API Response Status: {response.status_code}")
 
41
 
42
  if response.status_code == 200:
43
  result = response.json()
@@ -63,137 +67,4 @@ def query_gemini_api(prompt, max_retries=3):
63
  if attempt == max_retries - 1:
64
  raise Exception(error_msg)
65
 
66
- except requests.exceptions.Timeout:
67
- print(f"=== DEBUG: Timeout on attempt {attempt + 1}")
68
- if attempt == max_retries - 1:
69
- raise Exception("Request timed out after multiple attempts")
70
- time.sleep(5)
71
-
72
- except Exception as e:
73
- print(f"=== DEBUG: Exception on attempt {attempt + 1}: {e}")
74
- if attempt == max_retries - 1:
75
- raise e
76
- time.sleep(5)
77
-
78
- raise Exception("Failed to get response after all retries")
79
-
80
- def extract_user_requested_limit(nl_query):
81
- """Extract user-requested number from natural language query"""
82
- patterns = [
83
- r'\b(\d+)\s+(?:ships?|vessels?|boats?|records?|results?|entries?|names?)\b',
84
- r'(?:show|list|find|get)\s+(?:me\s+)?(?:the\s+)?(?:top\s+|first\s+)?(\d+)',
85
- r'(?:names\s+of\s+)(\d+)\s+',
86
- r'\b(\d+)\s+(?:oldest|newest|biggest|smallest|largest)',
87
- ]
88
-
89
- for pattern in patterns:
90
- match = re.search(pattern, nl_query, re.IGNORECASE)
91
- if match:
92
- return int(match.group(1))
93
- return None
94
-
95
- def clean_sql_output(sql_text, user_limit=None):
96
- """Clean and validate SQL output from the model"""
97
- sql_text = sql_text.strip()
98
-
99
- # Remove markdown formatting
100
- if sql_text.startswith("```"):
101
- lines = sql_text.split('\n')
102
- # Find SQL content between backticks
103
- sql_lines = []
104
- in_sql = False
105
- for line in lines:
106
- if line.strip().startswith("```"):
107
- in_sql = not in_sql
108
- continue
109
- if in_sql:
110
- sql_lines.append(line)
111
- sql_text = '\n'.join(sql_lines)
112
-
113
- # Handle multiple lines - extract the main SELECT query
114
- lines = sql_text.split('\n')
115
- sql = ""
116
- for line in lines:
117
- line = line.strip()
118
- if line and (line.upper().startswith('SELECT') or sql):
119
- sql += line + " "
120
- if line.endswith(';'):
121
- break
122
-
123
- if not sql:
124
- # If no SELECT found, take the first non-empty line that looks like SQL
125
- for line in lines:
126
- line = line.strip()
127
- if line and any(keyword in line.upper() for keyword in ['SELECT', 'WITH', 'FROM']):
128
- sql = line
129
- break
130
-
131
- sql = sql.strip().rstrip(';')
132
-
133
- # Apply user-requested limit
134
- if user_limit:
135
- sql = re.sub(r'\s+LIMIT\s+\d+', '', sql, flags=re.IGNORECASE)
136
- sql += f" LIMIT {user_limit}"
137
-
138
- return sql
139
-
140
- def text_to_sql(nl_query):
141
- """Convert natural language to SQL using Google Gemini"""
142
- try:
143
- print(f"=== DEBUG: Starting text_to_sql with query: {nl_query}")
144
-
145
- # Get database schema
146
- try:
147
- schema = get_schema()
148
- print(f"=== DEBUG: Schema retrieved, length: {len(schema)}")
149
- except Exception as e:
150
- print(f"=== DEBUG: Schema error: {e}")
151
- return f"Error: Database schema access failed: {str(e)}", []
152
-
153
- # Extract user limit
154
- user_limit = extract_user_requested_limit(nl_query)
155
- print(f"=== DEBUG: Extracted user limit: {user_limit}")
156
-
157
- # Create optimized prompt for Gemini
158
- prompt = f"""You are an expert PostgreSQL developer. Convert this natural language question to a precise SQL query.
159
-
160
- Question: {nl_query}
161
-
162
- Database Schema:
163
- {schema[:1500]}
164
-
165
- Requirements:
166
- - Generate ONLY the SQL query, no explanation
167
- - Use PostgreSQL syntax
168
- - Be precise with table and column names from the schema
169
- - Return a single SELECT statement
170
-
171
- SQL Query:"""
172
-
173
- print("=== DEBUG: Calling Google Gemini API...")
174
- generated_sql = query_gemini_api(prompt)
175
- print(f"=== DEBUG: Generated SQL raw: {generated_sql}")
176
-
177
- if not generated_sql or "No valid response" in generated_sql:
178
- return "Error: No SQL generated from Gemini", []
179
-
180
- # Clean the SQL output
181
- sql = clean_sql_output(generated_sql, user_limit)
182
- print(f"=== DEBUG: Final cleaned SQL: {sql}")
183
-
184
- if not sql or not sql.upper().strip().startswith('SELECT'):
185
- return f"Error: Invalid SQL generated: {sql}", []
186
-
187
- # Execute SQL
188
- print("=== DEBUG: Executing SQL...")
189
- try:
190
- results = execute_sql(sql)
191
- print(f"=== DEBUG: SQL executed successfully, {len(results)} results")
192
- return sql, results
193
- except Exception as e:
194
- print(f"=== DEBUG: SQL execution error: {e}")
195
- return f"Error: SQL execution failed: {str(e)}", []
196
-
197
- except Exception as e:
198
- print(f"=== DEBUG: General error in text_to_sql: {e}")
199
- return f"Error: {str(e)}", []
 
8
  def query_gemini_api(prompt, max_retries=3):
9
  """Query the Google Gemini API"""
10
  api_key = os.getenv("GOOGLE_API_KEY")
11
+ print(f"=== DEBUG: API Key Loaded: {api_key[:5]}...") # Partial key for debug
12
  if not api_key:
13
  raise ValueError("GOOGLE_API_KEY not found in environment variables")
14
 
15
  # Gemini API endpoint
16
  url = f"https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent?key={api_key}"
17
+ print(f"=== DEBUG: API URL: {url[:50]}...") # Partial URL for debug
18
 
19
  headers = {
20
  "Content-Type": "application/json"
 
34
  "stopSequences": ["```", "\n\n"]
35
  }
36
  }
37
+ print(f"=== DEBUG: Payload: {json.dumps(payload, indent=2)}")
38
 
39
  for attempt in range(max_retries):
40
  try:
41
+ print(f"=== DEBUG: Attempt {attempt + 1} of {max_retries}")
42
  response = requests.post(url, headers=headers, json=payload, timeout=30)
43
  print(f"=== DEBUG: API Response Status: {response.status_code}")
44
+ print(f"=== DEBUG: Response Text: {response.text[:200]}...") # Partial response
45
 
46
  if response.status_code == 200:
47
  result = response.json()
 
67
  if attempt == max_retries - 1:
68
  raise Exception(error_msg)
69
 
70
+ except requests.exceptions.Timeout