File size: 4,566 Bytes
6b5f4d7
eda83bc
 
 
6b5f4d7
 
eda83bc
 
6b5f4d7
eda83bc
 
 
 
 
a474ea6
eda83bc
 
 
 
 
 
 
 
 
 
a474ea6
eda83bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b5f4d7
eda83bc
 
a474ea6
eda83bc
a474ea6
3d893a8
 
 
 
a474ea6
eda83bc
 
 
a474ea6
eda83bc
a474ea6
eda83bc
 
a474ea6
eda83bc
 
a474ea6
eda83bc
a474ea6
eda83bc
 
a474ea6
 
3d893a8
 
a474ea6
3d893a8
eda83bc
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import os
import requests
import time
import re
from db_utils import get_schema, execute_sql

# Hugging Face Inference API endpoint for the defog/sqlcoder-7b-2 model;
# query_huggingface_api() POSTs every prompt to this URL.
API_URL = "https://api-inference.huggingface.co/models/defog/sqlcoder-7b-2"

def _extract_generated_text(result):
    """Return the stripped generated text from a successful (HTTP 200) response body."""
    if isinstance(result, list) and len(result) > 0:
        # `or ""` guards against an explicit null generated_text.
        return (result[0].get("generated_text") or "").strip()
    if isinstance(result, dict):
        if "error" in result:
            # An error payload must never be handed back as if it were model output.
            raise Exception(f"API Error: {result['error']}")
        if "generated_text" in result:
            return (result["generated_text"] or "").strip()
    return str(result).strip()


def query_huggingface_api(prompt, max_retries=3):
    """Query the Hugging Face Inference API and return the generated text.

    Args:
        prompt: Text sent as the request's ``inputs``.
        max_retries: Maximum number of HTTP attempts before giving up.

    Returns:
        The stripped ``generated_text`` of the first result. If the body has an
        unexpected shape, its string form is returned instead.

    Raises:
        ValueError: If the ``HF_TOKEN`` environment variable is missing or empty.
        Exception: If every attempt fails: HTTP errors, timeouts, error payloads,
            or the model still loading (HTTP 503) on the last attempt.
    """
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        raise ValueError("HF_TOKEN not found in environment variables. Add it to your Space secrets.")

    headers = {"Authorization": f"Bearer {hf_token}"}
    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": 200,
            "temperature": 0.1,
            "do_sample": False,
            "return_full_text": False
        }
    }

    for attempt in range(max_retries):
        is_last_attempt = attempt == max_retries - 1
        try:
            response = requests.post(API_URL, headers=headers, json=payload, timeout=30)

            if response.status_code == 200:
                return _extract_generated_text(response.json())

            if response.status_code == 503:
                if is_last_attempt:
                    # No retry follows, so waiting would only delay the failure.
                    break
                wait_time = 20 * (attempt + 1)
                print(f"Model loading, waiting {wait_time} seconds...")
                time.sleep(wait_time)
                continue

            error_msg = f"API Error {response.status_code}: {response.text}"
            if is_last_attempt:
                raise Exception(error_msg)
            # Back off before retrying, like the other failure paths do.
            time.sleep(5)

        except requests.exceptions.Timeout as exc:
            if is_last_attempt:
                raise Exception("Request timed out after multiple attempts") from exc
            time.sleep(5)

        except Exception:
            if is_last_attempt:
                raise
            time.sleep(5)

    raise Exception("Failed to get response after all retries")

def extract_user_requested_limit(nl_query):
    """Return the row count the user explicitly asked for, if any.

    The question is scanned with an ordered list of case-insensitive patterns
    (e.g. "5 ships", "show me the top 10", "names of 3 ...", "4 oldest"). The
    number captured by the first pattern that matches is returned.

    Args:
        nl_query: The user's natural-language question.

    Returns:
        The requested count as an int, or ``None`` when no pattern matches.
    """
    number_patterns = (
        # "<n> ships", "<n> records", ...
        r'\b(\d+)\s+(?:ships?|vessels?|boats?|records?|results?|entries?|names?)\b',
        # "show me the top <n>", "list first <n>", ...
        r'(?:show|list|find|get)\s+(?:me\s+)?(?:the\s+)?(?:top\s+|first\s+)?(\d+)',
        # "names of <n> ..."
        r'(?:names\s+of\s+)(\d+)\s+',
        # "<n> oldest", "<n> largest", ...
        r'\b(\d+)\s+(?:oldest|newest|biggest|smallest|largest)',
    )
    candidate_hits = (re.search(rx, nl_query, flags=re.IGNORECASE) for rx in number_patterns)
    first_hit = next((hit for hit in candidate_hits if hit), None)
    if first_hit is None:
        return None
    return int(first_hit.group(1))

def clean_sql_output(sql_text, user_limit=None):
    """Extract a single SELECT statement from raw model output.

    Markdown code fences are skipped wherever they appear. Any preamble before
    the first line starting with ``SELECT`` (case-insensitive) is ignored.
    Collection stops at a closing fence or at the first statement line ending
    with ``;``. The collected lines are joined with single spaces, and the
    trailing ``;`` is removed.

    Args:
        sql_text: Raw text returned by the model.
        user_limit: Row count the user asked for. When truthy and a statement
            was found, the statement's trailing ``LIMIT n`` (if any) is replaced
            by ``LIMIT user_limit``.

    Returns:
        The cleaned SQL string, or ``""`` when no line starts with ``SELECT``.
    """
    statement_lines = []
    for raw_line in sql_text.strip().split('\n'):
        line = raw_line.strip()

        # Fences (``` / ```sql) can open or close the code block anywhere in
        # the reply. Before the query they are noise. After it they end the
        # query, so explanation text that follows is not glued onto the SQL.
        if line.startswith('```'):
            if statement_lines:
                break
            continue

        # Skip prose and blank lines until the statement starts. A preamble
        # line ending in ';' must not abort extraction before SELECT is seen.
        if not statement_lines and not line.upper().startswith('SELECT'):
            continue

        if line:
            statement_lines.append(line)
        if line.endswith(';'):
            break

    sql = ' '.join(statement_lines).rstrip(';').rstrip()

    # Apply user-requested limit (never to an empty result).
    if user_limit and sql:
        # Replace only the statement-level LIMIT at the very end, optionally
        # followed by OFFSET. A LIMIT inside a subquery is left untouched.
        sql = re.sub(
            r'\s+LIMIT\s+\d+(?=(?:\s+OFFSET\s+\d+)?\s*$)',
            '',
            sql,
            flags=re.IGNORECASE,
        )
        sql += f" LIMIT {user_limit}"

    return sql

def text_to_sql(nl_query):
    """Convert a natural-language question to SQL via the Hugging Face API and run it.

    Args:
        nl_query: The user's question. It is embedded verbatim in the prompt.

    Returns:
        ``(sql, results)`` on success, where ``results`` is whatever
        ``execute_sql`` returns. On any failure, returns
        ``("Error: <message>", [])``: errors are reported in the return value,
        not raised.
    """
    try:
        schema = get_schema()
        user_limit = extract_user_requested_limit(nl_query)

        prompt = f"""### Task
Generate a PostgreSQL query to answer this question: {nl_query}

### Database Schema
{schema}

### Instructions
- Return only the SQL query
- Use PostgreSQL syntax
- Be precise with table and column names

### SQL Query:"""

        print("Querying Hugging Face Inference API...")
        generated_sql = query_huggingface_api(prompt)

        if not generated_sql:
            raise ValueError("No SQL generated from the model")

        sql = clean_sql_output(generated_sql, user_limit)

        if not sql or not sql.upper().startswith('SELECT'):
            raise ValueError(f"Invalid SQL generated: {sql}")

        # The question is user-controlled and pasted into the prompt, so the
        # model can be steered into emitting extra statements, e.g.
        # "SELECT 1; DROP TABLE ...". clean_sql_output only strips a trailing
        # ';', so any ';' left here signals more than one statement; refuse it
        # instead of passing it to execute_sql. (A ';' inside a string literal
        # is also rejected, which is an accepted false positive.)
        if ';' in sql:
            raise ValueError(f"Multiple SQL statements are not allowed: {sql}")

        print(f"Generated SQL: {sql}")
        results = execute_sql(sql)
        return sql, results

    except Exception as e:
        error_msg = str(e)
        print(f"Error in text_to_sql: {error_msg}")
        return f"Error: {error_msg}", []