Spaces:
Running
Running
File size: 4,566 Bytes
6b5f4d7 eda83bc 6b5f4d7 eda83bc 6b5f4d7 eda83bc a474ea6 eda83bc a474ea6 eda83bc 6b5f4d7 eda83bc a474ea6 eda83bc a474ea6 3d893a8 a474ea6 eda83bc a474ea6 eda83bc a474ea6 eda83bc a474ea6 eda83bc a474ea6 eda83bc a474ea6 eda83bc a474ea6 3d893a8 a474ea6 3d893a8 eda83bc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
import os
import requests
import time
import re
from db_utils import get_schema, execute_sql
# Hugging Face Inference API endpoint for the defog/sqlcoder-7b-2 model.
# query_huggingface_api() POSTs its generation requests to this URL.
API_URL = "https://api-inference.huggingface.co/models/defog/sqlcoder-7b-2"
def query_huggingface_api(prompt, max_retries=3):
    """Send *prompt* to the Hugging Face Inference API and return the generated text.

    The request is attempted at most ``max_retries`` times. A 503 response
    is treated as "model still loading" and retried after a pause that grows
    by 20 seconds per attempt. Timeouts, other non-200 responses and
    unexpected failures are retried after a 5 second pause. No pause is
    taken after the final attempt, so a permanently failing call returns
    control as soon as the last response is known.

    Args:
        prompt: Text sent as the ``inputs`` field of the request payload.
        max_retries: Maximum number of requests to make.

    Returns:
        The stripped ``generated_text`` of the first item when the API
        returns a non-empty list; otherwise the stripped ``str()`` of the
        decoded JSON payload.

    Raises:
        ValueError: If the ``HF_TOKEN`` environment variable is not set.
        Exception: If the last attempt times out, returns a non-200/non-503
            status, or fails unexpectedly, or if every attempt was used up
            without a usable response.
    """
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        raise ValueError("HF_TOKEN not found in environment variables. Add it to your Space secrets.")

    headers = {"Authorization": f"Bearer {hf_token}"}
    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": 200,
            "temperature": 0.1,
            "do_sample": False,
            "return_full_text": False
        }
    }

    for attempt in range(max_retries):
        is_last_attempt = attempt == max_retries - 1
        try:
            response = requests.post(API_URL, headers=headers, json=payload, timeout=30)
            if response.status_code == 200:
                result = response.json()
                if isinstance(result, list) and result:
                    return result[0].get("generated_text", "").strip()
                return str(result).strip()
        except requests.exceptions.Timeout as err:
            if is_last_attempt:
                raise Exception("Request timed out after multiple attempts") from err
            time.sleep(5)
            continue
        except Exception:
            # Network failures and malformed responses are retried too.
            if is_last_attempt:
                raise
            time.sleep(5)
            continue

        # Non-200 statuses are handled outside the ``try`` so the error
        # raised below is not swallowed by the broad handler above.
        if response.status_code == 503:
            if is_last_attempt:
                # Nothing left to retry: skip the (long) pointless wait.
                break
            wait_time = 20 * (attempt + 1)
            print(f"Model loading, waiting {wait_time} seconds...")
            time.sleep(wait_time)
            continue

        if is_last_attempt:
            raise Exception(f"API Error {response.status_code}: {response.text}")
        # Back off before retrying instead of hammering the API.
        time.sleep(5)

    raise Exception("Failed to get response after all retries")
def extract_user_requested_limit(nl_query):
    """Return the row count requested in *nl_query*, or ``None`` if there is none.

    The question is checked, case-insensitively, against a fixed list of
    phrasings ("5 ships", "show me the top 10", "names of 3 ...",
    "7 oldest"). The patterns are tried in order and the number captured by
    the first one that matches is returned as an ``int``.
    """
    count_patterns = (
        r'\b(\d+)\s+(?:ships?|vessels?|boats?|records?|results?|entries?|names?)\b',
        r'(?:show|list|find|get)\s+(?:me\s+)?(?:the\s+)?(?:top\s+|first\s+)?(\d+)',
        r'(?:names\s+of\s+)(\d+)\s+',
        r'\b(\d+)\s+(?:oldest|newest|biggest|smallest|largest)',
    )
    # Lazy generator: later patterns are only evaluated if earlier ones miss.
    attempts = (re.search(expr, nl_query, re.IGNORECASE) for expr in count_patterns)
    first_hit = next((hit for hit in attempts if hit), None)
    if first_hit is None:
        return None
    return int(first_hit.group(1))
def clean_sql_output(sql_text, user_limit=None):
    """Extract a single SELECT statement from raw model output.

    Collection starts at the first line beginning with ``SELECT``
    (case-insensitive) and stops at the first line ending with ``;`` or at
    a closing Markdown fence. Fence lines are never copied into the result,
    wherever they appear, so introductory prose, an unterminated fence, or
    commentary after the closing fence do not corrupt the statement. The
    collected lines are joined with single spaces and trailing semicolons
    are removed.

    Args:
        sql_text: Raw text returned by the model.
        user_limit: Optional row count requested by the user. When truthy,
            the statement's trailing ``LIMIT n`` (optionally followed by
            ``OFFSET m``) is dropped and ``LIMIT <user_limit>`` is appended.
            ``LIMIT`` clauses inside subqueries are left untouched.

    Returns:
        The cleaned SQL string, or ``""`` when no SELECT line was found.
    """
    fence = "```"
    sql_lines = []
    for raw_line in sql_text.strip().splitlines():
        line = raw_line.strip()
        if line.startswith(fence):
            if sql_lines:
                # Closing fence: the statement is complete.
                break
            # Opening fence (possibly with a language tag): skip it.
            continue

        # A fence glued to the end of the last SQL line also ends the query.
        closes_fence = line.endswith(fence)
        if closes_fence:
            line = line[:-len(fence)].rstrip()
        if not line:
            continue
        if not sql_lines and not line.upper().startswith('SELECT'):
            # Still in the preamble before the query.
            continue

        sql_lines.append(line)
        if closes_fence or line.endswith(';'):
            break

    sql = " ".join(sql_lines).rstrip(';').rstrip()

    # Apply user-requested limit
    if user_limit:
        # Only replace the outer LIMIT at the end of the statement; stripping
        # every LIMIT would silently change the meaning of subqueries.
        sql = re.sub(
            r'\s+LIMIT\s+\d+(?=\s*(?:OFFSET\s+\d+\s*)?$)',
            '',
            sql,
            flags=re.IGNORECASE,
        )
        sql += f" LIMIT {user_limit}"
    return sql
def text_to_sql(nl_query):
    """Convert a natural-language question to SQL, run it, and return both.

    The database schema from ``get_schema()`` and the question are placed
    in a prompt for the Hugging Face Inference API. The model output is
    cleaned with ``clean_sql_output`` (applying any row count found in the
    question), validated, and passed to ``execute_sql``.

    Only a single statement beginning with ``SELECT`` is executed. Output
    that still contains a ``;`` after cleaning is rejected, because a
    stacked statement such as ``SELECT 1; DROP TABLE t`` would otherwise
    pass the ``SELECT`` prefix check.

    Args:
        nl_query: The user's question in natural language.

    Returns:
        ``(sql, results)`` on success, where ``results`` is whatever
        ``execute_sql`` returns. On any failure the error is printed and
        ``("Error: <message>", [])`` is returned instead of raising.
    """
    try:
        schema = get_schema()
        user_limit = extract_user_requested_limit(nl_query)
        prompt = f"""### Task
Generate a PostgreSQL query to answer this question: {nl_query}
### Database Schema
{schema}
### Instructions
- Return only the SQL query
- Use PostgreSQL syntax
- Be precise with table and column names
### SQL Query:"""
        print("Querying Hugging Face Inference API...")
        generated_sql = query_huggingface_api(prompt)
        if not generated_sql:
            raise ValueError("No SQL generated from the model")

        sql = clean_sql_output(generated_sql, user_limit)
        if not sql or not sql.upper().startswith('SELECT'):
            raise ValueError(f"Invalid SQL generated: {sql}")
        if ';' in sql:
            # The model output is untrusted. A remaining semicolon means more
            # than one statement (or, rarely, a literal containing ';'), so
            # refuse rather than risk running a non-SELECT statement.
            raise ValueError(f"Invalid SQL generated (multiple statements are not allowed): {sql}")

        print(f"Generated SQL: {sql}")
        results = execute_sql(sql)
        return sql, results
    except Exception as e:
        # Top-level boundary: callers receive an error tuple, not an exception.
        error_msg = str(e)
        print(f"Error in text_to_sql: {error_msg}")
        return f"Error: {error_msg}", []