File size: 6,445 Bytes
6b5f4d7
eda83bc
 
 
6b5f4d7
 
4e15631
 
6b5f4d7
eda83bc
 
 
 
4e15631
a474ea6
eda83bc
 
 
 
 
 
 
4e15631
 
eda83bc
 
a474ea6
eda83bc
 
4e15631
eda83bc
4e15631
eda83bc
 
 
4e15631
 
eda83bc
4e15631
 
 
 
 
eda83bc
 
 
4e15631
eda83bc
 
 
 
 
4e15631
eda83bc
 
 
 
4e15631
eda83bc
 
 
 
 
4e15631
eda83bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4e15631
 
eda83bc
 
 
 
 
 
4e15631
 
 
 
 
 
 
 
 
 
eda83bc
 
 
 
 
 
 
 
 
 
 
4e15631
6b5f4d7
4e15631
 
 
 
 
 
 
 
 
 
 
eda83bc
4e15631
a474ea6
4e15631
 
 
3d893a8
4e15631
 
3d893a8
4e15631
 
eda83bc
 
4e15631
 
 
 
 
 
 
a474ea6
4e15631
eda83bc
4e15631
a474ea6
eda83bc
4e15631
a474ea6
4e15631
eda83bc
4e15631
a474ea6
eda83bc
4e15631
a474ea6
4e15631
 
 
 
 
 
 
 
 
a474ea6
3d893a8
4e15631
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
import os
import requests
import time
import re
from db_utils import get_schema, execute_sql

# Endpoint of the hosted Hugging Face Inference API for the
# Qwen2.5-Coder-7B-Instruct model; query_huggingface_api POSTs prompts here.
API_URL = (
    "https://api-inference.huggingface.co/models/"
    "Qwen/Qwen2.5-Coder-7B-Instruct"
)

def query_huggingface_api(prompt, max_retries=3):
    """Query the Hugging Face Inference API and return the generated text.

    Args:
        prompt: Fully formatted prompt string, sent as the ``inputs`` field.
        max_retries: Maximum number of HTTP attempts before giving up.

    Returns:
        The stripped ``generated_text`` of the first element when the API
        returns a non-empty list; otherwise ``str(result)`` stripped.

    Raises:
        ValueError: If ``HF_TOKEN`` is not set in the environment.
        Exception: If no attempt succeeds (HTTP error, timeout, or any other
            request/response-parsing error on the final attempt).
    """
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        raise ValueError("HF_TOKEN not found in environment variables")

    headers = {"Authorization": f"Bearer {hf_token}"}
    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": 200,
            "temperature": 0.1,
            "do_sample": False,
            "return_full_text": False,
            "stop": ["###", "\n\n"]
        }
    }

    for attempt in range(max_retries):
        is_last_attempt = attempt == max_retries - 1
        try:
            print(f"=== DEBUG: API attempt {attempt + 1}")
            response = requests.post(API_URL, headers=headers, json=payload, timeout=30)
            print(f"=== DEBUG: API Response Status: {response.status_code}")

            if response.status_code == 200:
                result = response.json()
                print(f"=== DEBUG: API Response: {result}")

                if isinstance(result, list) and len(result) > 0:
                    return result[0].get("generated_text", "").strip()
                return str(result).strip()

        except requests.exceptions.Timeout:
            print(f"=== DEBUG: Timeout on attempt {attempt + 1}")
            if is_last_attempt:
                raise Exception("Request timed out after multiple attempts")
            time.sleep(5)
            continue

        except Exception as e:
            # Network errors and malformed/unexpected 200 bodies are retried.
            print(f"=== DEBUG: Exception on attempt {attempt + 1}: {e}")
            if is_last_attempt:
                raise
            time.sleep(5)
            continue

        # Non-200 responses are handled outside the try so the exception
        # raised on the final attempt is not re-caught (and re-logged) above.
        if response.status_code == 503:
            if is_last_attempt:
                # No attempt remains, so waiting for the model would be wasted.
                break
            wait_time = 20 * (attempt + 1)
            print(f"=== DEBUG: Model loading, waiting {wait_time} seconds...")
            time.sleep(wait_time)
            continue

        error_msg = f"API Error {response.status_code}: {response.text}"
        print(f"=== DEBUG: {error_msg}")
        if is_last_attempt:
            raise Exception(error_msg)
        # Back off before retrying instead of hammering the endpoint
        # (e.g. on 429 rate limiting or transient 5xx errors).
        time.sleep(5)

    raise Exception("Failed to get response after all retries")

def extract_user_requested_limit(nl_query):
    """Return the row count the user explicitly asked for, or None.

    The query is tried against a fixed, ordered tuple of case-insensitive
    regexes (phrasings like "5 ships", "show me the top 10", "names of 3 ...",
    "4 oldest"). The digits captured by the first regex that matches are
    returned as an ``int``; later regexes are not evaluated.
    """
    limit_regexes = (
        r'\b(\d+)\s+(?:ships?|vessels?|boats?|records?|results?|entries?|names?)\b',
        r'(?:show|list|find|get)\s+(?:me\s+)?(?:the\s+)?(?:top\s+|first\s+)?(\d+)',
        r'(?:names\s+of\s+)(\d+)\s+',
        r'\b(\d+)\s+(?:oldest|newest|biggest|smallest|largest)',
    )

    # Lazily evaluated, so searching stops at the first successful regex.
    searches = (re.search(rx, nl_query, re.IGNORECASE) for rx in limit_regexes)
    first_hit = next((hit for hit in searches if hit), None)

    return None if first_hit is None else int(first_hit.group(1))

def clean_sql_output(sql_text, user_limit=None):
    """Extract a single SQL statement from raw model output.

    Args:
        sql_text: Raw text returned by the model. It may be wrapped in a
            markdown code fence and/or surrounded by other text.
        user_limit: Row count explicitly requested by the user. When truthy,
            every ``LIMIT <n>`` clause is removed from the SQL and
            ``LIMIT <user_limit>`` is appended.

    Returns:
        The SQL joined onto one line with single spaces and with trailing
        semicolons removed. Collection starts at the first line beginning with
        ``SELECT`` (case-insensitive). It stops after a line ending in ``;`` or
        at a markdown fence. If there is no ``SELECT`` line, the first
        non-empty line is returned instead.
    """
    lines = sql_text.strip().split('\n')

    # Remove markdown formatting. Drop the closing fence only if it is really
    # there: output cut off by max_new_tokens or a stop sequence can end
    # mid-query, and its last line is then SQL, not "```".
    if lines[0].startswith("```"):
        lines = lines[1:]
        if lines and lines[-1].strip().startswith("```"):
            lines = lines[:-1]

    # Collect from the first SELECT line up to the statement terminator.
    collected = []
    for raw_line in lines:
        line = raw_line.strip()
        if not line:
            continue
        if not collected:
            if not line.upper().startswith('SELECT'):
                continue
        elif line.startswith("```"):
            # A fence after the query started closes a code block that was
            # preceded by prose (e.g. "Here is the query:\n```sql\n...");
            # it is not part of the SQL.
            break
        collected.append(line)
        if line.endswith(';'):
            break
    sql = " ".join(collected)

    if not sql:
        # If no SELECT found, take the first non-empty line.
        sql = next((line.strip() for line in lines if line.strip()), "")

    sql = sql.strip().rstrip(';')

    # Apply user-requested limit, replacing whatever limit the model chose.
    if user_limit:
        sql = re.sub(r'\s+LIMIT\s+\d+', '', sql, flags=re.IGNORECASE)
        sql += f" LIMIT {user_limit}"

    return sql

def _has_multiple_statements(sql):
    """Return True if *sql* contains a ';' outside quoted text.

    Single-quoted string literals and double-quoted identifiers (with SQL's
    doubled-quote escaping) are removed first. A ';' inside a literal value
    therefore does not count as a statement separator. The combined pattern
    is scanned left to right, so whichever quote style opens first wins.
    """
    quoted_text = r"'(?:[^']|'')*'" + r'|"(?:[^"]|"")*"'
    return ";" in re.sub(quoted_text, "", sql)


def text_to_sql(nl_query):
    """Convert natural language to SQL using Qwen2.5-Coder via HF Inference API.

    The function:
    1. Builds a ChatML prompt from the database schema and *nl_query*.
    2. Asks the model for a query.
    3. Cleans the model output with ``clean_sql_output``, applying any row
       limit found by ``extract_user_requested_limit``.
    4. Runs the query with ``execute_sql``.

    Args:
        nl_query: The user's natural-language question.

    Returns:
        ``(sql, results)`` on success, where ``results`` is whatever
        ``execute_sql`` returned. On failure, returns ``("Error: <reason>", [])``.
        Errors are reported in the return value rather than raised, because
        the body is wrapped in a catch-all ``except Exception``.
    """
    try:
        print(f"=== DEBUG: Starting text_to_sql with query: {nl_query}")

        # Get database schema
        try:
            schema = get_schema()
            print(f"=== DEBUG: Schema retrieved, length: {len(schema)}")
        except Exception as e:
            print(f"=== DEBUG: Schema error: {e}")
            return f"Error: Database schema access failed: {str(e)}", []

        # Extract user limit
        user_limit = extract_user_requested_limit(nl_query)
        print(f"=== DEBUG: Extracted user limit: {user_limit}")

        # Create optimized prompt for Qwen2.5-Coder. Only the first 1500
        # characters of the schema are included; anything past that is not
        # visible to the model. NOTE(review): presumably a prompt-size budget.
        prompt = f"""<|im_start|>system
You are an expert SQL developer. Generate PostgreSQL queries based on natural language questions.

Database Schema:
{schema[:1500]}

Rules:
- Return ONLY the SQL query
- Use PostgreSQL syntax
- Be precise with table and column names
- Do not include explanations or markdown formatting
<|im_end|>
<|im_start|>user
{nl_query}
<|im_end|>
<|im_start|>assistant
"""

        print("=== DEBUG: Calling Qwen2.5-Coder API...")
        generated_sql = query_huggingface_api(prompt)
        print(f"=== DEBUG: Generated SQL raw: {generated_sql}")

        if not generated_sql:
            return "Error: No SQL generated from the model", []

        # Clean the SQL output
        sql = clean_sql_output(generated_sql, user_limit)
        print(f"=== DEBUG: Final cleaned SQL: {sql}")

        if not sql or not sql.upper().startswith('SELECT'):
            return f"Error: Invalid SQL generated: {sql}", []

        # The SQL is model output steered by user text, so treat it as
        # untrusted. A single line such as "SELECT 1; DROP TABLE ships" passes
        # the SELECT-prefix check above, so stacked statements are rejected.
        # NOTE(review): a SELECT prefix alone does not make a query read-only
        # (e.g. SELECT ... INTO, or calls to side-effecting functions).
        # Executing under a read-only role/transaction is the robust defence;
        # confirm how db_utils.execute_sql connects.
        if _has_multiple_statements(sql):
            return f"Error: Multiple SQL statements are not allowed: {sql}", []

        # Execute SQL
        print("=== DEBUG: Executing SQL...")
        try:
            results = execute_sql(sql)
            print(f"=== DEBUG: SQL executed successfully, {len(results)} results")
            return sql, results
        except Exception as e:
            print(f"=== DEBUG: SQL execution error: {e}")
            return f"Error: SQL execution failed: {str(e)}", []

    except Exception as e:
        print(f"=== DEBUG: General error in text_to_sql: {e}")
        return f"Error: {str(e)}", []