Spaces:
Running
Running
File size: 7,419 Bytes
6b5f4d7 eda83bc 5f5380d 6b5f4d7 5f5380d e186332 5f5380d d339486 5f5380d a474ea6 eda83bc 5f5380d eda83bc 5f5380d eda83bc e186332 a474ea6 eda83bc e186332 5f5380d 4e15631 e186332 eda83bc 4e15631 5f5380d eda83bc 5f5380d eda83bc 5f5380d 4e15631 eda83bc d339486 50136a9 d339486 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 |
import os
import requests
import time
import re
import json
from db_utils import get_schema, execute_sql
def _first_candidate_text(result):
    """Return the stripped text of the first candidate's first part, or the sentinel.

    Missing or empty ``candidates`` / ``content`` / ``parts`` / ``text`` all
    yield "No valid response generated" instead of raising, so an empty
    reply is reported to the caller rather than retried as a network error.
    """
    candidates = result.get("candidates") or []
    if candidates:
        parts = (candidates[0].get("content") or {}).get("parts") or []
        if parts and "text" in parts[0]:
            return parts[0]["text"].strip()
    return "No valid response generated"


def query_gemini_api(prompt, max_retries=3):
    """Query the Google Gemini API (gemini-1.5-flash) and return the generated text.

    Args:
        prompt: Full prompt text, sent as a single user turn.
        max_retries: Maximum number of HTTP attempts.

    Returns:
        The stripped text of the first candidate's first part, or the
        sentinel string "No valid response generated" when a 200 response
        carries no usable candidate text.

    Raises:
        ValueError: GOOGLE_API_KEY is unset or empty.
        Exception: the final attempt failed (non-200 status, timeout, or any
            other error), or every attempt was rate limited.
    """
    api_key = os.getenv("GOOGLE_API_KEY")
    # Validate before touching the key: slicing a missing (None) key used to
    # raise TypeError and mask this ValueError.
    if not api_key:
        raise ValueError("GOOGLE_API_KEY not found in environment variables")
    print(f"=== DEBUG: API Key Loaded: {api_key[:5]}...")  # Partial key for debug

    # Gemini API endpoint. The key travels as a query parameter, so only a
    # prefix of the URL is logged (the first 50 chars end before "?key=").
    url = f"https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent?key={api_key}"
    print(f"=== DEBUG: API URL: {url[:50]}...")
    headers = {"Content-Type": "application/json"}
    # NOTE(review): with "```" as a stop sequence, a reply that opens with a
    # markdown fence may stop before producing any text, and "\n\n" cuts a
    # query that contains a blank line. Confirm this is intended.
    payload = {
        "contents": [{
            "parts": [{
                "text": prompt
            }]
        }],
        "generationConfig": {
            "temperature": 0.1,
            "topK": 1,
            "topP": 0.8,
            "maxOutputTokens": 200,
            "stopSequences": ["```", "\n\n"]
        }
    }
    print(f"=== DEBUG: Payload: {json.dumps(payload, indent=2)}")

    for attempt in range(max_retries):
        is_last_attempt = attempt == max_retries - 1
        try:
            print(f"=== DEBUG: Attempt {attempt + 1} of {max_retries}")
            response = requests.post(url, headers=headers, json=payload, timeout=30)
            print(f"=== DEBUG: API Response Status: {response.status_code}")
            print(f"=== DEBUG: Response Text: {response.text[:200]}...")  # Partial response
            if response.status_code == 200:
                result = response.json()
                print(f"=== DEBUG: API Response: {result}")
                return _first_candidate_text(result)
            if response.status_code == 429:
                if is_last_attempt:
                    # No retry follows, so sleeping would only delay the failure.
                    break
                wait_time = 60 * (attempt + 1)  # Rate limit - wait longer
                print(f"=== DEBUG: Rate limited, waiting {wait_time} seconds...")
                time.sleep(wait_time)
                continue
            error_msg = f"Gemini API Error {response.status_code}: {response.text}"
            print(f"=== DEBUG: {error_msg}")
            if is_last_attempt:
                # Caught by the generic handler below, which re-raises it.
                raise Exception(error_msg)
            time.sleep(5)  # brief pause instead of retrying immediately
        except requests.exceptions.Timeout:
            print(f"=== DEBUG: Timeout on attempt {attempt + 1}")
            if is_last_attempt:
                raise Exception("Request timed out after multiple attempts")
            time.sleep(5)
        except Exception as e:
            print(f"=== DEBUG: Exception on attempt {attempt + 1}: {e}")
            if is_last_attempt:
                raise
            time.sleep(5)
    raise Exception("Failed to get response after all retries")
def extract_user_requested_limit(nl_query):
    """Return the row count the user asked for in *nl_query*, or None.

    The patterns are tried in priority order: an explicit "<n> ships"-style
    count, then "show/list/find/get [me] [the] [top|first] <n>", then
    "names of <n>", then "<n> oldest/newest/...". The first pattern that
    matches anywhere in the query decides the result (case-insensitive).
    """
    count_patterns = (
        r'\b(\d+)\s+(?:ships?|vessels?|boats?|records?|results?|entries?|names?)\b',
        r'(?:show|list|find|get)\s+(?:me\s+)?(?:the\s+)?(?:top\s+|first\s+)?(\d+)',
        r'(?:names\s+of\s+)(\d+)\s+',
        r'\b(\d+)\s+(?:oldest|newest|biggest|smallest|largest)',
    )
    # Lazy generator: later patterns are only searched if earlier ones miss.
    searches = (re.search(pat, nl_query, re.IGNORECASE) for pat in count_patterns)
    first_hit = next((hit for hit in searches if hit), None)
    return int(first_hit.group(1)) if first_hit else None
# A `LIMIT <n>` clause together with the whitespace before it.
_LIMIT_CLAUSE_RE = re.compile(r"\s+LIMIT\s+\d+\b", re.IGNORECASE)


def _strip_top_level_limits(sql):
    """Remove `LIMIT <n>` clauses from *sql* that are not inside parentheses.

    LIMITs nested in subqueries are kept because they change which rows the
    inner query yields. Parenthesis depth is estimated by counting '(' and
    ')' characters, so parentheses inside string literals can mislead it.
    """
    kept = []
    cursor = 0
    for match in _LIMIT_CLAUSE_RE.finditer(sql):
        before = sql[:match.start()]
        if before.count("(") == before.count(")"):
            kept.append(sql[cursor:match.start()])
            cursor = match.end()
    kept.append(sql[cursor:])
    return "".join(kept)


def clean_sql_output(sql_text, user_limit=None):
    """Extract one SQL query from raw model output and optionally cap its rows.

    Args:
        sql_text: Raw text returned by the model. It may be wrapped in a
            markdown fence or surrounded by other lines.
        user_limit: Row count requested by the user. When truthy, top-level
            `LIMIT <n>` clauses are removed and `LIMIT <user_limit>` is
            appended on its own line.

    Returns:
        The query, from the first line starting with SELECT up to the first
        line ending in ';' or the next markdown fence. Its stripped lines are
        joined with newlines and a trailing ';' is removed. If no line starts
        with SELECT, the first line mentioning SELECT/WITH/FROM is returned
        as-is. Returns "" when nothing matches or *sql_text* is empty/None.
    """
    if not sql_text:
        return ""
    sql_text = sql_text.strip()

    # Output that opens with a markdown fence: keep only the fenced content.
    if sql_text.startswith("```"):
        fenced = []
        in_fence = False
        for line in sql_text.split("\n"):
            if line.strip().startswith("```"):
                in_fence = not in_fence
                continue
            if in_fence:
                fenced.append(line)
        sql_text = "\n".join(fenced)

    lines = [line.strip() for line in sql_text.split("\n")]

    # Collect the main SELECT query. Lines are joined with newlines, not
    # spaces, so a `-- comment` cannot swallow the lines that follow it.
    collected = []
    for line in lines:
        if line.startswith("```"):
            if collected:
                break  # closing fence ends the query; never include it
            continue
        if not line:
            continue
        if collected or line.upper().startswith("SELECT"):
            collected.append(line)
            if line.endswith(";"):
                break
    sql = "\n".join(collected)

    if not sql:
        # If no SELECT line was found, take the first non-empty line that looks like SQL.
        sql = next(
            (
                line
                for line in lines
                if line and any(kw in line.upper() for kw in ("SELECT", "WITH", "FROM"))
            ),
            "",
        )

    sql = sql.strip().rstrip(";").rstrip()

    # Apply the user-requested limit. The limit goes on a new line so that a
    # trailing `-- comment` cannot disable it, and only when a query exists.
    if user_limit and sql:
        sql = _strip_top_level_limits(sql)
        sql = f"{sql}\nLIMIT {int(user_limit)}"
    return sql
# Maximum number of schema characters embedded in the prompt. Tables beyond
# this cutoff are not visible to the model.
_MAX_SCHEMA_CHARS = 1500


def _build_sql_prompt(nl_query, schema):
    """Return the Gemini prompt asking for a single PostgreSQL SELECT answering *nl_query*."""
    return f"""You are an expert PostgreSQL developer. Convert this natural language question to a precise SQL query.
Question: {nl_query}
Database Schema:
{schema[:_MAX_SCHEMA_CHARS]}
Requirements:
- Generate ONLY the SQL query, no explanation
- Use PostgreSQL syntax
- Be precise with table and column names from the schema
- Return a single SELECT statement
SQL Query:"""


def text_to_sql(nl_query):
    """Convert natural language to SQL using Google Gemini, then execute it.

    Args:
        nl_query: The user's question in plain language.

    Returns:
        A ``(sql, results)`` tuple. On success ``sql`` is the executed query
        and ``results`` is whatever ``execute_sql`` returned. On failure
        ``sql`` is a message starting with "Error:" and ``results`` is an
        empty list; ordinary exceptions are caught and reported this way.
    """
    if not nl_query or not str(nl_query).strip():
        # Nothing to translate: skip the schema lookup and the API call.
        return "Error: Empty query", []
    try:
        print(f"=== DEBUG: Starting text_to_sql with query: {nl_query}")
        try:
            schema = get_schema()
            print(f"=== DEBUG: Schema retrieved, length: {len(schema)}")
        except Exception as e:
            print(f"=== DEBUG: Schema error: {e}")
            return f"Error: Database schema access failed: {e}", []

        user_limit = extract_user_requested_limit(nl_query)
        print(f"=== DEBUG: Extracted user limit: {user_limit}")

        prompt = _build_sql_prompt(nl_query, schema)
        print("=== DEBUG: Calling Google Gemini API...")
        generated_sql = query_gemini_api(prompt)
        print(f"=== DEBUG: Generated SQL raw: {generated_sql}")
        if not generated_sql or "No valid response" in generated_sql:
            return "Error: No SQL generated from Gemini", []

        sql = clean_sql_output(generated_sql, user_limit)
        print(f"=== DEBUG: Final cleaned SQL: {sql}")
        # The SQL is model output driven by user text, so treat it as untrusted.
        # Accept only a single statement that starts with SELECT. Any ';' left
        # after cleaning could chain a second statement ("SELECT 1; DROP TABLE t").
        # NOTE(review): this is a heuristic; a read-only database role is the
        # real safeguard, since a SELECT can still call side-effecting functions.
        if not sql or not sql.upper().strip().startswith("SELECT") or ";" in sql:
            return f"Error: Invalid SQL generated: {sql}", []

        print("=== DEBUG: Executing SQL...")
        try:
            results = execute_sql(sql)
            print(f"=== DEBUG: SQL executed successfully, {len(results)} results")
            return sql, results
        except Exception as e:
            print(f"=== DEBUG: SQL execution error: {e}")
            return f"Error: SQL execution failed: {e}", []
    except Exception as e:
        print(f"=== DEBUG: General error in text_to_sql: {e}")
        return f"Error: {e}", []
|