Spaces:
Running
Running
File size: 6,445 Bytes
6b5f4d7 eda83bc 6b5f4d7 4e15631 6b5f4d7 eda83bc 4e15631 a474ea6 eda83bc 4e15631 eda83bc a474ea6 eda83bc 4e15631 eda83bc 4e15631 eda83bc 4e15631 eda83bc 4e15631 eda83bc 4e15631 eda83bc 4e15631 eda83bc 4e15631 eda83bc 4e15631 eda83bc 4e15631 eda83bc 4e15631 eda83bc 4e15631 6b5f4d7 4e15631 eda83bc 4e15631 a474ea6 4e15631 3d893a8 4e15631 3d893a8 4e15631 eda83bc 4e15631 a474ea6 4e15631 eda83bc 4e15631 a474ea6 eda83bc 4e15631 a474ea6 4e15631 eda83bc 4e15631 a474ea6 eda83bc 4e15631 a474ea6 4e15631 a474ea6 3d893a8 4e15631 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 |
import os
import requests
import time
import re
from db_utils import get_schema, execute_sql
# Hugging Face Inference API endpoint for the Qwen2.5-Coder-7B-Instruct model.
# query_huggingface_api() POSTs its generation requests to this URL.
API_URL = "https://api-inference.huggingface.co/models/Qwen/Qwen2.5-Coder-7B-Instruct"
def query_huggingface_api(prompt, max_retries=3):
    """Send ``prompt`` to the Hugging Face Inference API and return the generated text.

    The request is POSTed to ``API_URL`` with a bearer token read from the
    ``HF_TOKEN`` environment variable and is attempted at most ``max_retries``
    times.

    Args:
        prompt: Full prompt text, sent as the ``inputs`` field of the payload.
        max_retries: Maximum number of attempts before giving up.

    Returns:
        The stripped ``generated_text`` of the first result when the decoded
        JSON body is a non-empty list; otherwise the stripped ``str()`` of the
        decoded body.

    Raises:
        ValueError: If ``HF_TOKEN`` is not set in the environment.
        Exception: If the final attempt times out, returns a non-200/non-503
            status, or raises; also if every attempt was answered with 503.
    """
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        raise ValueError("HF_TOKEN not found in environment variables")

    headers = {"Authorization": f"Bearer {hf_token}"}
    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": 200,
            "temperature": 0.1,
            "do_sample": False,
            "return_full_text": False,
            "stop": ["###", "\n\n"],
        },
    }

    for attempt in range(max_retries):
        is_last_attempt = attempt == max_retries - 1
        try:
            print(f"=== DEBUG: API attempt {attempt + 1}")
            response = requests.post(API_URL, headers=headers, json=payload, timeout=30)
            print(f"=== DEBUG: API Response Status: {response.status_code}")

            if response.status_code == 200:
                result = response.json()
                print(f"=== DEBUG: API Response: {result}")
                if isinstance(result, list) and len(result) > 0:
                    generated_text = result[0].get("generated_text", "").strip()
                else:
                    generated_text = str(result).strip()
                return generated_text

            if response.status_code == 503:
                # 503 is treated as "model still loading": back off linearly
                # (20s, 40s, ...) before the next attempt.
                if is_last_attempt:
                    # No retry follows, so sleeping here would only delay the
                    # failure raised after the loop.
                    print("=== DEBUG: Model still loading, no retries left")
                    break
                wait_time = 20 * (attempt + 1)
                print(f"=== DEBUG: Model loading, waiting {wait_time} seconds...")
                time.sleep(wait_time)
                continue

            error_msg = f"API Error {response.status_code}: {response.text}"
            print(f"=== DEBUG: {error_msg}")
            if is_last_attempt:
                # Raised inside the try on purpose: the generic handler below
                # logs it and re-raises it unchanged on the last attempt.
                raise Exception(error_msg)
        except requests.exceptions.Timeout:
            print(f"=== DEBUG: Timeout on attempt {attempt + 1}")
            if is_last_attempt:
                raise Exception("Request timed out after multiple attempts")
            time.sleep(5)
        except Exception as e:
            print(f"=== DEBUG: Exception on attempt {attempt + 1}: {e}")
            if is_last_attempt:
                raise
            time.sleep(5)

    raise Exception("Failed to get response after all retries")
# Patterns are tried in order; the first pattern that matches anywhere in the
# query wins, even if a later pattern would match earlier in the text.
_USER_LIMIT_PATTERNS = tuple(
    re.compile(pattern, re.IGNORECASE)
    for pattern in (
        # "<n> ships", "<n> records", "<n> entry/entries", ...
        r'\b(\d+)\s+(?:ships?|vessels?|boats?|records?|results?|entr(?:y|ies?)|names?)\b',
        # "show me the top <n>", "list first <n>", "get <n>", ...
        r'(?:show|list|find|get)\s+(?:me\s+)?(?:the\s+)?(?:top\s+|first\s+)?(\d+)',
        # "names of <n> ..."
        r'(?:names\s+of\s+)(\d+)\s+',
        # "<n> oldest", "<n> largest", ...
        r'\b(\d+)\s+(?:oldest|newest|biggest|smallest|largest)',
    )
)


def extract_user_requested_limit(nl_query):
    """Return the row count the user asked for in ``nl_query``, if any.

    The query is matched case-insensitively against a fixed list of phrasings
    such as "5 ships", "show me the top 10" or "3 oldest".

    Args:
        nl_query: The natural-language question typed by the user.

    Returns:
        The requested number as an ``int``, or ``None`` when no pattern matches.
    """
    for pattern in _USER_LIMIT_PATTERNS:
        match = pattern.search(nl_query)
        if match:
            return int(match.group(1))
    return None
_SQL_LIMIT_CLAUSE_RE = re.compile(r'\s+LIMIT\s+\d+', re.IGNORECASE)


def _strip_top_level_limits(sql):
    """Remove ``LIMIT <n>`` clauses that are not nested inside parentheses."""
    kept = []
    depth = 0
    cursor = 0
    for match in _SQL_LIMIT_CLAUSE_RE.finditer(sql):
        gap = sql[cursor:match.start()]
        # The matched clause itself never contains parentheses, so counting
        # them in the gaps is enough to know the nesting depth at each match.
        depth += gap.count('(') - gap.count(')')
        kept.append(gap)
        if depth > 0:
            # Inside a subquery: this LIMIT belongs to the inner SELECT.
            kept.append(match.group(0))
        cursor = match.end()
    kept.append(sql[cursor:])
    return ''.join(kept)


def clean_sql_output(sql_text, user_limit=None):
    """Extract a single SQL statement from raw model output.

    Markdown code-fence lines are dropped wherever they appear. The statement
    is collected from the first line starting with ``SELECT`` up to a line
    ending in ``;``, a closing fence, or the end of the text, and its lines
    are joined with single spaces. If no ``SELECT`` line exists, the first
    non-empty line is returned instead (preferring a non-fence line).

    Args:
        sql_text: Raw text generated by the model.
        user_limit: Row count requested by the user. When truthy, any
            top-level ``LIMIT <n>`` is replaced by ``LIMIT <user_limit>``;
            ``LIMIT`` clauses inside parenthesised subqueries are kept.

    Returns:
        The cleaned SQL without a trailing semicolon.
    """
    lines = [line.strip() for line in sql_text.strip().split('\n')]

    statement = []
    for line in lines:
        if line.startswith("```"):
            if statement:
                # Closing fence: the statement is complete.
                break
            # Opening fence (possibly after some prose): skip it.
            continue
        if not line:
            continue
        if statement or line.upper().startswith('SELECT'):
            statement.append(line)
            if line.endswith(';'):
                break

    if statement:
        sql = ' '.join(statement)
    else:
        # No SELECT found: fall back to the first non-empty line.
        non_empty = [line for line in lines if line]
        unfenced = [line for line in non_empty if not line.startswith("```")]
        sql = (unfenced or non_empty or [""])[0]

    sql = sql.strip().rstrip(';').rstrip()

    # Apply the user-requested limit in place of the model's own.
    if user_limit:
        sql = _strip_top_level_limits(sql) + f" LIMIT {user_limit}"
    return sql
def text_to_sql(nl_query):
    """Turn a natural-language question into SQL, execute it, and return both.

    The database schema (first 1500 characters) and the question are wrapped
    in a chat-style prompt, sent through ``query_huggingface_api``, and the
    reply is normalised with ``clean_sql_output`` before being executed.

    Args:
        nl_query: The question typed by the user.

    Returns:
        ``(sql, results)`` on success. On any failure the first element is a
        message starting with ``"Error:"`` and the second is an empty list.
    """
    try:
        print(f"=== DEBUG: Starting text_to_sql with query: {nl_query}")

        # Step 1: fetch the schema that grounds the prompt.
        try:
            db_schema = get_schema()
            print(f"=== DEBUG: Schema retrieved, length: {len(db_schema)}")
        except Exception as schema_err:
            print(f"=== DEBUG: Schema error: {schema_err}")
            return f"Error: Database schema access failed: {schema_err!s}", []

        # Step 2: see whether the user asked for a specific number of rows.
        requested_limit = extract_user_requested_limit(nl_query)
        print(f"=== DEBUG: Extracted user limit: {requested_limit}")

        # Step 3: build the chat-formatted prompt for Qwen2.5-Coder.
        schema_excerpt = db_schema[:1500]
        prompt = f"""<|im_start|>system
You are an expert SQL developer. Generate PostgreSQL queries based on natural language questions.
Database Schema:
{schema_excerpt}
Rules:
- Return ONLY the SQL query
- Use PostgreSQL syntax
- Be precise with table and column names
- Do not include explanations or markdown formatting
<|im_end|>
<|im_start|>user
{nl_query}
<|im_end|>
<|im_start|>assistant
"""

        # Step 4: ask the model.
        print("=== DEBUG: Calling Qwen2.5-Coder API...")
        raw_sql = query_huggingface_api(prompt)
        print(f"=== DEBUG: Generated SQL raw: {raw_sql}")
        if not raw_sql:
            return "Error: No SQL generated from the model", []

        # Step 5: normalise the reply and accept only SELECT statements.
        final_sql = clean_sql_output(raw_sql, requested_limit)
        print(f"=== DEBUG: Final cleaned SQL: {final_sql}")
        if not (final_sql and final_sql.upper().startswith('SELECT')):
            return f"Error: Invalid SQL generated: {final_sql}", []

        # Step 6: run the query.
        print("=== DEBUG: Executing SQL...")
        try:
            rows = execute_sql(final_sql)
            print(f"=== DEBUG: SQL executed successfully, {len(rows)} results")
            return final_sql, rows
        except Exception as exec_err:
            print(f"=== DEBUG: SQL execution error: {exec_err}")
            return f"Error: SQL execution failed: {exec_err!s}", []
    except Exception as unexpected:
        print(f"=== DEBUG: General error in text_to_sql: {unexpected}")
        return f"Error: {unexpected!s}", []