acadiaway committed on
Commit
e6cde35
·
1 Parent(s): e96349e
Files changed (1) hide show
  1. pipeline.py +182 -0
pipeline.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import requests
3
+ import time
4
+ import re
5
+ import json
6
+ from db_utils import get_schema, execute_sql
7
+
8
def query_gemini_api(prompt, max_retries=3):
    """Send ``prompt`` to the Gemini ``generateContent`` endpoint and return its text.

    Args:
        prompt: Text sent as the single content part of the request.
        max_retries: Maximum number of HTTP attempts before giving up.

    Returns:
        The stripped text of the first part of the first candidate. Two
        sentinel strings are returned instead of raising:
        ``"Error: The response was blocked by safety filters."`` when the
        candidate's ``finishReason`` is ``SAFETY``, and
        ``"No valid response generated"`` when a 200 reply carries no text.

    Raises:
        ValueError: If ``GOOGLE_API_KEY`` is not set in the environment.
        Exception: If the final attempt times out, the final attempt gets a
            status other than 200/429, or every attempt is used up without
            a reply. Any other error raised on the final attempt propagates
            unchanged.
    """
    api_key = os.getenv("GOOGLE_API_KEY")
    if not api_key:
        raise ValueError("GOOGLE_API_KEY not found in environment variables")

    url = (
        "https://generativelanguage.googleapis.com/v1beta/models/"
        "gemini-1.5-flash:generateContent"
    )

    # The key travels in a header rather than the query string so it cannot
    # leak through exception messages or logs that include the request URL.
    headers = {
        "Content-Type": "application/json",
        "x-goog-api-key": api_key,
    }

    payload = {
        "contents": [{
            "parts": [{
                "text": prompt
            }]
        }],
        "generationConfig": {
            "temperature": 0.1,
            "topK": 1,
            "topP": 0.8,
            "maxOutputTokens": 200,
            # NOTE(review): generation stops at the first code fence or blank
            # line, presumably to keep the reply to one bare statement -- confirm.
            "stopSequences": ["```", "\n\n"]
        },
        "safetySettings": [
            {
                "category": "HARM_CATEGORY_HARASSMENT",
                "threshold": "BLOCK_NONE"
            },
            {
                "category": "HARM_CATEGORY_HATE_SPEECH",
                "threshold": "BLOCK_NONE"
            },
            {
                "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
                "threshold": "BLOCK_NONE"
            },
            {
                "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
                "threshold": "BLOCK_NONE"
            }
        ]
    }

    for attempt in range(max_retries):
        is_last_attempt = attempt == max_retries - 1
        try:
            response = requests.post(url, headers=headers, json=payload, timeout=30)

            if response.status_code == 200:
                result = response.json()
                candidates = result.get("candidates") or []
                if candidates:
                    candidate = candidates[0]
                    if candidate.get("finishReason") == "SAFETY":
                        return "Error: The response was blocked by safety filters."
                    # Content, parts or text may be absent; report that as
                    # "no response" instead of indexing blindly.
                    parts = (candidate.get("content") or {}).get("parts") or []
                    if parts and "text" in parts[0]:
                        return parts[0]["text"].strip()

                return "No valid response generated"

            if response.status_code == 429:
                # Rate limited: back off linearly, but never sleep when no
                # further attempt will follow.
                if not is_last_attempt:
                    time.sleep(60 * (attempt + 1))
                continue

            error_msg = f"Gemini API Error {response.status_code}: {response.text}"
            if is_last_attempt:
                raise Exception(error_msg)
            # Pause before retrying, like the other failure paths do.
            time.sleep(5)

        except requests.exceptions.Timeout as exc:
            if is_last_attempt:
                raise Exception("Request timed out after multiple attempts") from exc
            time.sleep(5)

        except Exception:
            if is_last_attempt:
                raise
            time.sleep(5)

    raise Exception("Failed to get response after all retries")
89
+
90
def extract_user_requested_limit(nl_query):
    """Return the row count the user asked for in ``nl_query``, or ``None``.

    The patterns are tried in order, case-insensitively, and the first one
    that matches supplies the number (its first capture group).
    """
    limit_patterns = (
        # "<n> ships", "<n> records", ...
        r'\b(\d+)\s+(?:ships?|vessels?|boats?|records?|results?|entries?|names?)\b',
        # "show me the top <n>", "list first <n>", ...
        r'(?:show|list|find|get)\s+(?:me\s+)?(?:the\s+)?(?:top\s+|first\s+)?(\d+)',
        # "names of <n> ..."
        r'(?:names\s+of\s+)(\d+)\s+',
        # "<n> oldest", "<n> largest", ...
        r'\b(\d+)\s+(?:oldest|newest|biggest|smallest|largest)',
    )

    attempts = (re.search(regex, nl_query, re.IGNORECASE) for regex in limit_patterns)
    first_hit = next((hit for hit in attempts if hit), None)
    return int(first_hit.group(1)) if first_hit else None
104
+
105
def clean_sql_output(sql_text, user_limit=None):
    """Reduce raw model output to a single-line SQL statement.

    Steps:
      1. Drop Markdown code-fence lines. When a fenced block has content, only
         that content is kept; otherwise the text outside the fences is used.
      2. Join the lines from the first one starting with ``SELECT`` up to and
         including the first line ending in ``;``.
      3. If no line starts with ``SELECT``, fall back to the first line that
         mentions ``SELECT``, ``WITH`` or ``FROM``.
      4. Strip the trailing semicolon and, when ``user_limit`` is truthy and
         SQL was found, replace any ``LIMIT <n>`` with ``LIMIT {user_limit}``.

    Args:
        sql_text: Raw text produced by the model.
        user_limit: Row limit requested by the user, or ``None``.

    Returns:
        The cleaned SQL, or an empty string when nothing SQL-like was found.
    """
    lines = sql_text.strip().split('\n')

    # Fences are handled wherever they occur, not only at the very start, so
    # a prose preamble followed by a fenced query does not leak "```" into
    # the statement.
    if any(line.strip().startswith("```") for line in lines):
        inside_fence = False
        fenced, unfenced = [], []
        for line in lines:
            if line.strip().startswith("```"):
                inside_fence = not inside_fence
                continue
            (fenced if inside_fence else unfenced).append(line)
        lines = fenced if any(line.strip() for line in fenced) else unfenced

    stripped_lines = [line.strip() for line in lines]

    statement_parts = []
    for line in stripped_lines:
        if not line:
            continue
        if statement_parts or line.upper().startswith('SELECT'):
            statement_parts.append(line)
            if line.endswith(';'):
                break
    sql = " ".join(statement_parts)

    if not sql:
        # Last resort: take the first line that at least looks like SQL.
        sql = next(
            (
                line
                for line in stripped_lines
                if any(keyword in line.upper() for keyword in ('SELECT', 'WITH', 'FROM'))
            ),
            "",
        )

    # The second rstrip removes a space left behind by "... ;".
    sql = sql.strip().rstrip(';').rstrip()

    # Only touch LIMIT when there is a statement to attach it to; otherwise
    # the result would be the bare fragment " LIMIT n".
    if user_limit and sql:
        # NOTE: this removes every "LIMIT <n>" in the statement, including
        # ones inside subqueries, before appending the requested limit.
        sql = re.sub(r'\s+LIMIT\s+\d+', '', sql, flags=re.IGNORECASE)
        sql += f" LIMIT {user_limit}"

    return sql
144
+
145
def text_to_sql(nl_query):
    """Translate a natural-language question into SQL via Gemini and run it.

    Returns a ``(sql_or_message, rows)`` pair. On success the first item is
    the executed SQL and the second is whatever ``execute_sql`` returned.
    When the model gives no usable answer, the cleaned SQL is not a SELECT,
    or any exception is raised, the first item is a message string and the
    second is an empty list.
    """
    try:
        schema = get_schema()
        requested_limit = extract_user_requested_limit(nl_query)

        # Only the first 1500 characters of the schema go into the prompt.
        schema_excerpt = schema[:1500]
        prompt = f"""You are an expert PostgreSQL developer. Convert this natural language question to a precise SQL query.

Question: {nl_query}

Database Schema:
{schema_excerpt}

Requirements:
- Generate ONLY the SQL query, no explanation
- Use PostgreSQL syntax
- Be precise with table and column names from the schema
- Return a single SELECT statement

SQL Query:"""

        model_reply = query_gemini_api(prompt)

        # An empty reply or one of the API helper's sentinel messages is
        # handed back unchanged, with no rows.
        failure_markers = ("No valid response", "Error:")
        if not model_reply or any(marker in model_reply for marker in failure_markers):
            return model_reply, []

        statement = clean_sql_output(model_reply, requested_limit)

        # Anything that does not begin with SELECT is never executed.
        is_select = bool(statement) and statement.upper().strip().startswith('SELECT')
        if not is_select:
            return f"Error: Invalid SQL generated: {statement}", []

        rows = execute_sql(statement)
        return statement, rows

    except Exception as exc:
        return f"Error: {str(exc)}", []
181
+
182
+ #--end-of-script