Spaces:
Running
Running
syntax2
Browse files- pipeline.py +0 -128
pipeline.py
DELETED
@@ -1,128 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import requests
|
3 |
-
import time
|
4 |
-
import re
|
5 |
-
import json
|
6 |
-
from db_utils import get_schema, execute_sql
|
7 |
-
|
8 |
-
def query_gemini_api(prompt, max_retries=3):
|
9 |
-
"""Query the Google Gemini API"""
|
10 |
-
api_key = os.getenv("GOOGLE_API_KEY")
|
11 |
-
if not api_key:
|
12 |
-
raise ValueError("GOOGLE_API_KEY not found in environment variables")
|
13 |
-
|
14 |
-
# Gemini API endpoint
|
15 |
-
url = f"https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent?key={api_key}"
|
16 |
-
|
17 |
-
headers = {
|
18 |
-
"Content-Type": "application/json"
|
19 |
-
}
|
20 |
-
|
21 |
-
payload = {
|
22 |
-
"contents": [{
|
23 |
-
"parts": [{
|
24 |
-
"text": prompt
|
25 |
-
}]
|
26 |
-
}],
|
27 |
-
"generationConfig": {
|
28 |
-
"temperature": 0.1,
|
29 |
-
"topK": 1,
|
30 |
-
"topP": 0.8,
|
31 |
-
"maxOutputTokens": 200,
|
32 |
-
"stopSequences": ["```", "\n\n"]
|
33 |
-
},
|
34 |
-
"safetySettings": [
|
35 |
-
{
|
36 |
-
"category": "HARM_CATEGORY_HARASSMENT",
|
37 |
-
"threshold": "BLOCK_NONE"
|
38 |
-
},
|
39 |
-
{
|
40 |
-
"category": "HARM_CATEGORY_HATE_SPEECH",
|
41 |
-
"threshold": "BLOCK_NONE"
|
42 |
-
},
|
43 |
-
{
|
44 |
-
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
45 |
-
"threshold": "BLOCK_NONE"
|
46 |
-
},
|
47 |
-
{
|
48 |
-
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
|
49 |
-
"threshold": "BLOCK_NONE"
|
50 |
-
}
|
51 |
-
]
|
52 |
-
}
|
53 |
-
|
54 |
-
for attempt in range(max_retries):
|
55 |
-
try:
|
56 |
-
response = requests.post(url, headers=headers, json=payload, timeout=30)
|
57 |
-
|
58 |
-
if response.status_code == 200:
|
59 |
-
result = response.json()
|
60 |
-
if "candidates" in result and len(result["candidates"]) > 0:
|
61 |
-
candidate = result["candidates"][0]
|
62 |
-
# Check for safety ratings first to see if it was blocked
|
63 |
-
if candidate.get('finishReason') == 'SAFETY':
|
64 |
-
print("=== DEBUG: Gemini response blocked due to safety settings.")
|
65 |
-
return "Error: The response was blocked by safety filters."
|
66 |
-
if "content" in candidate and "parts" in candidate["content"]:
|
67 |
-
generated_text = candidate["content"]["parts"][0]["text"].strip()
|
68 |
-
return generated_text
|
69 |
-
|
70 |
-
return "No valid response generated"
|
71 |
-
|
72 |
-
elif response.status_code == 429:
|
73 |
-
wait_time = 60 * (attempt + 1)
|
74 |
-
time.sleep(wait_time)
|
75 |
-
continue
|
76 |
-
else:
|
77 |
-
error_msg = f"Gemini API Error {response.status_code}: {response.text}"
|
78 |
-
if attempt == max_retries - 1:
|
79 |
-
raise Exception(error_msg)
|
80 |
-
|
81 |
-
except requests.exceptions.Timeout:
|
82 |
-
if attempt == max_retries - 1:
|
83 |
-
raise Exception("Request timed out after multiple attempts")
|
84 |
-
time.sleep(5)
|
85 |
-
|
86 |
-
except Exception as e:
|
87 |
-
if attempt == max_retries - 1:
|
88 |
-
raise e
|
89 |
-
time.sleep(5)
|
90 |
-
|
91 |
-
raise Exception("Failed to get response after all retries")
|
92 |
-
|
93 |
-
def extract_user_requested_limit(nl_query):
|
94 |
-
"""Extract user-requested number from natural language query"""
|
95 |
-
patterns = [
|
96 |
-
r'\b(\d+)\s+(?:ships?|vessels?|boats?|records?|results?|entries?|names?)\b',
|
97 |
-
r'(?:show|list|find|get)\s+(?:me\s+)?(?:the\s+)?(?:top\s+|first\s+)?(\d+)',
|
98 |
-
r'(?:names\s+of\s+)(\d+)\s+',
|
99 |
-
r'\b(\d+)\s+(?:oldest|newest|biggest|smallest|largest)',
|
100 |
-
]
|
101 |
-
|
102 |
-
for pattern in patterns:
|
103 |
-
match = re.search(pattern, nl_query, re.IGNORECASE)
|
104 |
-
if match:
|
105 |
-
return int(match.group(1))
|
106 |
-
return None
|
107 |
-
|
108 |
-
def clean_sql_output(sql_text, user_limit=None):
|
109 |
-
"""Clean and validate SQL output from the model"""
|
110 |
-
sql_text = sql_text.strip()
|
111 |
-
|
112 |
-
if sql_text.startswith("```"):
|
113 |
-
lines = sql_text.split('\n')
|
114 |
-
sql_lines = []
|
115 |
-
in_sql = False
|
116 |
-
for line in lines:
|
117 |
-
if line.strip().startswith("```"):
|
118 |
-
in_sql = not in_sql
|
119 |
-
continue
|
120 |
-
if in_sql:
|
121 |
-
sql_lines.append(line)
|
122 |
-
sql_text = '\n'.join(sql_lines)
|
123 |
-
|
124 |
-
lines = sql_text.split('\n')
|
125 |
-
sql = ""
|
126 |
-
for line in lines:
|
127 |
-
line = line.strip()
|
128 |
-
if line and (line.upper().startswith('SELECT') or sql
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|