File size: 7,419 Bytes
6b5f4d7
eda83bc
 
 
5f5380d
6b5f4d7
 
5f5380d
 
 
e186332
5f5380d
 
 
 
 
d339486
5f5380d
 
 
 
a474ea6
eda83bc
5f5380d
 
 
 
 
 
eda83bc
5f5380d
 
 
 
eda83bc
 
e186332
a474ea6
eda83bc
 
e186332
5f5380d
4e15631
e186332
eda83bc
 
 
4e15631
 
5f5380d
 
 
 
 
 
 
eda83bc
5f5380d
 
 
eda83bc
 
 
 
5f5380d
4e15631
eda83bc
 
 
d339486
50136a9
 
 
 
 
 
 
 
 
 
 
 
d339486
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
import os
import requests
import time
import re
import json
from db_utils import get_schema, execute_sql

def query_gemini_api(prompt, max_retries=3):
    """Send a prompt to the Google Gemini API and return the generated text.

    Args:
        prompt: Text sent as the single part of the request contents.
        max_retries: Maximum number of POST attempts before giving up.

    Returns:
        The stripped text of the first part of the first candidate, or the
        literal string "No valid response generated" when a 200 response
        contains no usable candidate text.

    Raises:
        ValueError: If the GOOGLE_API_KEY environment variable is unset or empty.
        Exception: If the final attempt times out, returns a non-200/non-429
            status, or raises; also when every attempt was rate limited.
    """
    api_key = os.getenv("GOOGLE_API_KEY")
    # Validate before slicing: os.getenv returns None for a missing variable,
    # and None[:5] would raise TypeError instead of the intended ValueError.
    if not api_key:
        raise ValueError("GOOGLE_API_KEY not found in environment variables")
    print(f"=== DEBUG: API Key Loaded: {api_key[:5]}...")  # Partial key for debug

    # Gemini API endpoint
    url = f"https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent?key={api_key}"
    # Only the first 50 characters are printed, which stops before the key.
    print(f"=== DEBUG: API URL: {url[:50]}...")

    headers = {
        "Content-Type": "application/json"
    }

    payload = {
        "contents": [{
            "parts": [{
                "text": prompt
            }]
        }],
        "generationConfig": {
            "temperature": 0.1,
            "topK": 1,
            "topP": 0.8,
            "maxOutputTokens": 200,
            "stopSequences": ["```", "\n\n"]
        }
    }
    print(f"=== DEBUG: Payload: {json.dumps(payload, indent=2)}")

    for attempt in range(max_retries):
        is_last_attempt = attempt == max_retries - 1
        try:
            print(f"=== DEBUG: Attempt {attempt + 1} of {max_retries}")
            response = requests.post(url, headers=headers, json=payload, timeout=30)
            print(f"=== DEBUG: API Response Status: {response.status_code}")
            print(f"=== DEBUG: Response Text: {response.text[:200]}...")  # Partial response

            if response.status_code == 200:
                result = response.json()
                print(f"=== DEBUG: API Response: {result}")

                if "candidates" in result and len(result["candidates"]) > 0:
                    candidate = result["candidates"][0]
                    if "content" in candidate and "parts" in candidate["content"]:
                        parts = candidate["content"]["parts"]
                        # An empty parts list or a part without text is
                        # another "no usable content" case, not an error.
                        if parts and "text" in parts[0]:
                            return parts[0]["text"].strip()

                return "No valid response generated"

            elif response.status_code == 429:
                if is_last_attempt:
                    # No retry follows, so sleeping would only delay the failure.
                    print("=== DEBUG: Rate limited on final attempt, giving up")
                    break
                wait_time = 60 * (attempt + 1)  # Rate limit - wait longer
                print(f"=== DEBUG: Rate limited, waiting {wait_time} seconds...")
                time.sleep(wait_time)
                continue

            else:
                error_msg = f"Gemini API Error {response.status_code}: {response.text}"
                print(f"=== DEBUG: {error_msg}")
                if is_last_attempt:
                    # Raised inside the try on purpose: the generic handler
                    # below logs it and re-raises on the last attempt.
                    raise Exception(error_msg)

        except requests.exceptions.Timeout:
            print(f"=== DEBUG: Timeout on attempt {attempt + 1}")
            if is_last_attempt:
                raise Exception("Request timed out after multiple attempts")
            time.sleep(5)

        except Exception as e:
            print(f"=== DEBUG: Exception on attempt {attempt + 1}: {e}")
            if is_last_attempt:
                raise  # bare raise keeps the original traceback
            time.sleep(5)

    raise Exception("Failed to get response after all retries")

def extract_user_requested_limit(nl_query):
    """Return the row count the user asked for in *nl_query*, or None.

    The patterns are tried in the order listed, case-insensitively; the first
    one that matches anywhere in the text supplies the number.
    """
    limit_patterns = (
        # "<n> ships", "<n> records", ...
        r'\b(\d+)\s+(?:ships?|vessels?|boats?|records?|results?|entries?|names?)\b',
        # "show me the top <n>", "list <n>", ...
        r'(?:show|list|find|get)\s+(?:me\s+)?(?:the\s+)?(?:top\s+|first\s+)?(\d+)',
        # "names of <n> ..."
        r'(?:names\s+of\s+)(\d+)\s+',
        # "<n> oldest", "<n> largest", ...
        r'\b(\d+)\s+(?:oldest|newest|biggest|smallest|largest)',
    )

    attempts = (re.search(expr, nl_query, re.IGNORECASE) for expr in limit_patterns)
    first_hit = next((hit for hit in attempts if hit is not None), None)
    if first_hit is None:
        return None
    return int(first_hit.group(1))

def clean_sql_output(sql_text, user_limit=None):
    """Reduce raw model output to a single SQL statement string.

    Steps, in order:
      1. If the text opens with a ``` fence, keep only the lines inside fences.
      2. Starting at the first line that begins with SELECT, join the
         non-empty lines with spaces, stopping after a line ending in ';'.
      3. If no line begins with SELECT, fall back to the first line that
         contains SELECT, WITH or FROM.
      4. Drop trailing semicolons.
      5. When *user_limit* is truthy, remove every existing ``LIMIT <n>``
         and append ``LIMIT <user_limit>``.
    """
    text = sql_text.strip()

    # Fenced output: toggle on each ``` line and collect what lies in between.
    if text.startswith("```"):
        fenced = []
        inside = False
        for raw in text.split('\n'):
            if raw.strip().startswith("```"):
                inside = not inside
            elif inside:
                fenced.append(raw)
        text = '\n'.join(fenced)

    stripped_lines = [raw.strip() for raw in text.split('\n')]

    # Gather the statement from its SELECT line up to the terminating ';'.
    pieces = []
    for entry in stripped_lines:
        if not entry:
            continue
        if not pieces and not entry.upper().startswith('SELECT'):
            continue
        pieces.append(entry)
        if entry.endswith(';'):
            break

    if pieces:
        statement = " ".join(pieces)
    else:
        # No SELECT-led line: take the first line that at least looks like SQL.
        statement = next(
            (
                entry
                for entry in stripped_lines
                if entry and any(kw in entry.upper() for kw in ('SELECT', 'WITH', 'FROM'))
            ),
            "",
        )

    statement = statement.strip().rstrip(';')

    # The user's explicit count overrides whatever LIMIT the model produced.
    if user_limit:
        statement = re.sub(r'\s+LIMIT\s+\d+', '', statement, flags=re.IGNORECASE)
        statement += f" LIMIT {user_limit}"

    return statement

def text_to_sql(nl_query):
    """Turn a natural-language question into SQL via Gemini and execute it.

    Returns:
        A ``(sql, results)`` pair on success. On any failure the first
        element is a message beginning with "Error:" and the second is an
        empty list; no exception escapes this function.
    """
    try:
        print(f"=== DEBUG: Starting text_to_sql with query: {nl_query}")

        # Schema lookup has its own handler so the error names the schema step.
        try:
            schema = get_schema()
            print(f"=== DEBUG: Schema retrieved, length: {len(schema)}")
        except Exception as schema_err:
            print(f"=== DEBUG: Schema error: {schema_err}")
            return f"Error: Database schema access failed: {str(schema_err)}", []

        # Row count explicitly requested in the question, if any.
        user_limit = extract_user_requested_limit(nl_query)
        print(f"=== DEBUG: Extracted user limit: {user_limit}")

        # Only the first 1500 characters of the schema go into the prompt.
        prompt = f"""You are an expert PostgreSQL developer. Convert this natural language question to a precise SQL query.

Question: {nl_query}

Database Schema:
{schema[:1500]}

Requirements:
- Generate ONLY the SQL query, no explanation
- Use PostgreSQL syntax
- Be precise with table and column names from the schema
- Return a single SELECT statement

SQL Query:"""

        print("=== DEBUG: Calling Google Gemini API...")
        raw_sql = query_gemini_api(prompt)
        print(f"=== DEBUG: Generated SQL raw: {raw_sql}")

        if not raw_sql or "No valid response" in raw_sql:
            return "Error: No SQL generated from Gemini", []

        sql = clean_sql_output(raw_sql, user_limit)
        print(f"=== DEBUG: Final cleaned SQL: {sql}")

        # Anything that is not a SELECT statement is rejected before execution.
        if not (sql and sql.upper().strip().startswith('SELECT')):
            return f"Error: Invalid SQL generated: {sql}", []

        print("=== DEBUG: Executing SQL...")
        try:
            rows = execute_sql(sql)
            print(f"=== DEBUG: SQL executed successfully, {len(rows)} results")
        except Exception as exec_err:
            print(f"=== DEBUG: SQL execution error: {exec_err}")
            return f"Error: SQL execution failed: {str(exec_err)}", []
        return sql, rows

    except Exception as err:
        print(f"=== DEBUG: General error in text_to_sql: {err}")
        return f"Error: {str(err)}", []