Spaces:
Running
Running
import os | |
import requests | |
import time | |
import re | |
from db_utils import get_schema, execute_sql | |
# Hugging Face Inference API endpoint for the Qwen2.5-Coder 7B Instruct model;
# query_huggingface_api POSTs generation requests to this URL.
API_URL = (
    "https://api-inference.huggingface.co/models/"
    "Qwen/Qwen2.5-Coder-7B-Instruct"
)
def _extract_generated_text(result):
    """Return the generated text from a decoded Inference API response body.

    A non-empty list is treated as a list of candidate dicts and the first
    candidate's ``generated_text`` is used. A bare dict carrying
    ``generated_text`` is accepted too. Any other shape is stringified so the
    caller still gets something to log or reject.
    """
    if isinstance(result, list) and len(result) > 0:
        return result[0].get("generated_text", "").strip()
    if isinstance(result, dict) and "generated_text" in result:
        return str(result["generated_text"]).strip()
    return str(result).strip()


def query_huggingface_api(prompt, max_retries=3):
    """Query the Hugging Face Inference API and return the generated text.

    Args:
        prompt: Complete prompt string, sent as ``inputs``. The caller is
            responsible for any chat template.
        max_retries: Maximum number of HTTP attempts. With ``0`` no request is
            made and the final "Failed ..." exception is raised immediately.

    Returns:
        The stripped generated text (see ``_extract_generated_text``).

    Raises:
        ValueError: If the ``HF_TOKEN`` environment variable is missing or empty.
        Exception: When every attempt fails. The exception carries one of:
            the last HTTP error message, a timeout message, the last
            unexpected exception (re-raised as-is), or
            "Failed to get response after all retries".
    """
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        raise ValueError("HF_TOKEN not found in environment variables")

    headers = {"Authorization": f"Bearer {hf_token}"}
    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": 200,
            "temperature": 0.1,
            "do_sample": False,
            "return_full_text": False,  # return only the completion, not the prompt
            "stop": ["###", "\n\n"],
        },
    }

    for attempt in range(max_retries):
        is_last_attempt = attempt == max_retries - 1
        print(f"=== DEBUG: API attempt {attempt + 1}")
        try:
            response = requests.post(API_URL, headers=headers, json=payload, timeout=30)
            print(f"=== DEBUG: API Response Status: {response.status_code}")
            if response.status_code == 200:
                result = response.json()
                print(f"=== DEBUG: API Response: {result}")
                return _extract_generated_text(result)
        except requests.exceptions.Timeout as exc:
            print(f"=== DEBUG: Timeout on attempt {attempt + 1}")
            if is_last_attempt:
                raise Exception("Request timed out after multiple attempts") from exc
            time.sleep(5)
            continue
        except Exception as exc:
            # Connection failures, malformed JSON, unexpected response shapes.
            print(f"=== DEBUG: Exception on attempt {attempt + 1}: {exc}")
            if is_last_attempt:
                raise
            time.sleep(5)
            continue

        # Non-200 responses are handled outside the try so the error raised on
        # the final attempt is not caught (and re-logged) by the handler above.
        if response.status_code == 503:
            # Treated as "model still loading": wait longer on each attempt.
            # Skip the wait on the final attempt, since no retry would follow it.
            if is_last_attempt:
                break
            wait_time = 20 * (attempt + 1)
            print(f"=== DEBUG: Model loading, waiting {wait_time} seconds...")
            time.sleep(wait_time)
            continue

        error_msg = f"API Error {response.status_code}: {response.text}"
        print(f"=== DEBUG: {error_msg}")
        if is_last_attempt:
            raise Exception(error_msg)
        # Back off like the other retry paths instead of re-sending instantly
        # (a rate-limit response, for example, needs time to clear).
        time.sleep(5)

    raise Exception("Failed to get response after all retries")
def extract_user_requested_limit(nl_query):
    """Return the number of rows the user asked for in *nl_query*, or None.

    The patterns below are tried in order, case-insensitively. The integer
    captured by the first pattern that matches anywhere in the query is
    returned.
    """
    limit_patterns = (
        # "5 ships", "10 records", "3 names", ...
        r'\b(\d+)\s+(?:ships?|vessels?|boats?|records?|results?|entries?|names?)\b',
        # "show me the top 5", "list first 10", "get 3", ...
        r'(?:show|list|find|get)\s+(?:me\s+)?(?:the\s+)?(?:top\s+|first\s+)?(\d+)',
        # "names of 4 ..."
        r'(?:names\s+of\s+)(\d+)\s+',
        # "5 oldest", "10 largest", ...
        r'\b(\d+)\s+(?:oldest|newest|biggest|smallest|largest)',
    )
    # Lazily search so later patterns are only tried when earlier ones miss.
    searches = (re.search(p, nl_query, re.IGNORECASE) for p in limit_patterns)
    first_hit = next((m for m in searches if m), None)
    return None if first_hit is None else int(first_hit.group(1))
def clean_sql_output(sql_text, user_limit=None):
    """Extract a single SQL statement from raw model output.

    Markdown code-fence markers (```` ```sql ```` / ```` ``` ````) are removed
    from every line, wherever they appear. Collection starts at the first line
    beginning with ``SELECT`` (case-insensitive) and stops after a collected
    line ending in ``;``. Collected lines are joined with single spaces. If no
    line starts with ``SELECT``, the first non-empty line is returned instead,
    so the caller can report what the model produced. Trailing semicolons and
    whitespace are removed.

    Args:
        sql_text: Raw text returned by the model.
        user_limit: Row count requested by the user. When truthy and the SQL
            is non-empty, a LIMIT that ends the statement (optionally followed
            by OFFSET) is removed and ``LIMIT <user_limit>`` is appended. A
            ``0``/``None`` value leaves the SQL's own LIMIT untouched.

    Returns:
        The cleaned SQL string, which may be empty.
    """
    # Strip fence markers per line rather than dropping the first/last line
    # blindly: the closing fence may be missing (output truncated) or the
    # fence may follow some prose.
    cleaned_lines = [
        re.sub(r'^```[A-Za-z]*\s*|\s*```$', '', raw_line.strip()).strip()
        for raw_line in sql_text.strip().split('\n')
    ]

    collected = []
    for line in cleaned_lines:
        if line and (collected or line.upper().startswith('SELECT')):
            collected.append(line)
            if line.endswith(';'):
                break
    sql = " ".join(collected)

    if not sql:
        # No SELECT found: fall back to the first non-empty line.
        sql = next((line for line in cleaned_lines if line), "")

    sql = sql.strip().rstrip(';').rstrip()

    if user_limit and sql:
        # Only replace a LIMIT that ends the statement, so LIMITs inside
        # subqueries keep their meaning. PostgreSQL accepts OFFSET before
        # LIMIT, so a trailing "LIMIT n OFFSET m" becomes "OFFSET m LIMIT k".
        sql = re.sub(
            r'\s+LIMIT\s+(?:\d+|ALL)\b(?=(?:\s+OFFSET\s+\d+)?\s*$)',
            '',
            sql,
            flags=re.IGNORECASE,
        )
        sql += f" LIMIT {user_limit}"
    return sql
def text_to_sql(nl_query):
    """Convert a natural-language question to SQL via Qwen2.5-Coder, run it, return both.

    Steps:
    1. Read the schema with ``get_schema``.
    2. Build a ChatML-style prompt.
    3. Query the model through ``query_huggingface_api``.
    4. Clean the reply with ``clean_sql_output``, applying any row count
       found by ``extract_user_requested_limit``.
    5. Validate the result and execute it with ``execute_sql``.

    Args:
        nl_query: The user's question in natural language.

    Returns:
        ``(sql, results)`` on success. On failure, ``(message, [])``, where
        *message* is a string starting with ``"Error: "``. Exceptions are
        caught and reported this way rather than raised.
    """
    try:
        print(f"=== DEBUG: Starting text_to_sql with query: {nl_query}")

        try:
            schema = get_schema()
            print(f"=== DEBUG: Schema retrieved, length: {len(schema)}")
        except Exception as e:
            print(f"=== DEBUG: Schema error: {e}")
            return f"Error: Database schema access failed: {e}", []

        user_limit = extract_user_requested_limit(nl_query)
        print(f"=== DEBUG: Extracted user limit: {user_limit}")

        # ChatML-formatted prompt (<|im_start|>/<|im_end|> role markers).
        # Only the first 1500 characters of the schema are included, to bound
        # prompt size; tables beyond that point are invisible to the model.
        prompt = f"""<|im_start|>system
You are an expert SQL developer. Generate PostgreSQL queries based on natural language questions.
Database Schema:
{schema[:1500]}
Rules:
- Return ONLY the SQL query
- Use PostgreSQL syntax
- Be precise with table and column names
- Do not include explanations or markdown formatting
<|im_end|>
<|im_start|>user
{nl_query}
<|im_end|>
<|im_start|>assistant
"""
        print("=== DEBUG: Calling Qwen2.5-Coder API...")
        generated_sql = query_huggingface_api(prompt)
        print(f"=== DEBUG: Generated SQL raw: {generated_sql}")
        if not generated_sql:
            return "Error: No SQL generated from the model", []

        sql = clean_sql_output(generated_sql, user_limit)
        print(f"=== DEBUG: Final cleaned SQL: {sql}")
        if not sql or not sql.upper().startswith('SELECT'):
            return f"Error: Invalid SQL generated: {sql}", []
        # The SQL comes from a model prompted with untrusted user text.
        # clean_sql_output strips trailing semicolons, so any ';' still present
        # means several statements (e.g. "SELECT 1; DROP TABLE ..."). Refuse
        # them. This rarely false-positives on a ';' inside a string literal.
        # NOTE(review): this prefix/semicolon check is not a complete safeguard;
        # confirm that execute_sql runs with a read-only database role.
        if ';' in sql:
            return f"Error: Multiple SQL statements are not allowed: {sql}", []

        print("=== DEBUG: Executing SQL...")
        try:
            results = execute_sql(sql)
            print(f"=== DEBUG: SQL executed successfully, {len(results)} results")
            return sql, results
        except Exception as e:
            print(f"=== DEBUG: SQL execution error: {e}")
            return f"Error: SQL execution failed: {e}", []
    except Exception as e:
        print(f"=== DEBUG: General error in text_to_sql: {e}")
        return f"Error: {e}", []