# gemini_nl2sql / pipeline.py
# Last commit: eda83bc (acadiaway) - "Switch to HF Inference API approach - eliminate model loading"
# Size: 4.57 kB
import os
import requests
import time
import re
from db_utils import get_schema, execute_sql
# Hugging Face Inference API endpoint for the defog/sqlcoder-7b-2 model,
# which text_to_sql prompts to produce PostgreSQL queries.
API_URL = (
    "https://api-inference.huggingface.co/models/"
    "defog/sqlcoder-7b-2"
)
def _extract_generated_text(result):
    """Return the generated text from a decoded Inference API JSON reply.

    A non-empty list whose first item is a dict yields that item's
    ``generated_text`` (stripped; ``""`` if missing or null). A dict that
    carries an ``"error"`` key is raised as an ``Exception``. Any other shape
    falls back to ``str(result).strip()``.
    """
    if isinstance(result, list) and result and isinstance(result[0], dict):
        return (result[0].get("generated_text") or "").strip()
    if isinstance(result, dict) and "error" in result:
        # NOTE(review): assumes the API may report a failure as {"error": ...};
        # raising here beats handing the error text downstream as "SQL".
        raise Exception(f"API Error: {result['error']}")
    return str(result).strip()


def query_huggingface_api(prompt, max_retries=3):
    """Send *prompt* to the Hugging Face Inference API and return the generated text.

    The request authenticates with the ``HF_TOKEN`` environment variable as a
    bearer token and asks for at most 200 new tokens with sampling disabled.

    Retry policy (``max_retries`` is the total number of attempts):

    * HTTP 503 is logged as "model loading" and waits 20 s, 40 s, ... before
      the next attempt.
    * A timeout, any other HTTP status, or any other exception waits 5 s
      before the next attempt.
    * No sleep is performed after the final attempt.

    Args:
        prompt: Text sent as the ``inputs`` field of the request.
        max_retries: Total number of attempts; ``0`` or less makes no request.

    Returns:
        The generated text, stripped (see ``_extract_generated_text``).

    Raises:
        ValueError: If ``HF_TOKEN`` is not set.
        Exception: If the final attempt fails. The message is one of
            "API Error <status>: <body>" or "Request timed out after
            multiple attempts", or the original exception is re-raised.
            "Failed to get response after all retries" is raised when the
            last attempt got a 503 or no attempt was made.
    """
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        raise ValueError("HF_TOKEN not found in environment variables. Add it to your Space secrets.")

    headers = {"Authorization": f"Bearer {hf_token}"}
    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": 200,
            # NOTE(review): with do_sample=False, temperature presumably has no effect.
            "temperature": 0.1,
            "do_sample": False,
            "return_full_text": False,
        },
    }

    for attempt in range(max_retries):
        is_last_attempt = attempt == max_retries - 1
        try:
            response = requests.post(API_URL, headers=headers, json=payload, timeout=30)
            if response.status_code == 200:
                return _extract_generated_text(response.json())
            if response.status_code == 503:
                # 503 is treated as "model still loading": back off linearly,
                # but don't sleep when no further attempt will follow.
                if is_last_attempt:
                    break
                wait_time = 20 * (attempt + 1)
                print(f"Model loading, waiting {wait_time} seconds...")
                time.sleep(wait_time)
                continue
            raise Exception(f"API Error {response.status_code}: {response.text}")
        except requests.exceptions.Timeout as exc:
            if is_last_attempt:
                raise Exception("Request timed out after multiple attempts") from exc
        except Exception:
            if is_last_attempt:
                raise
        # Short pause before retrying after a timeout or non-503 failure.
        time.sleep(5)

    raise Exception("Failed to get response after all retries")
def extract_user_requested_limit(nl_query):
    """Return the row count the user asked for in *nl_query*, or ``None``.

    The question is scanned with a few English phrasings, tried in order;
    the first match wins:

    * a number followed by a noun such as "ships", "vessels", "records" or
      "names" (``"5 ships"``);
    * "show"/"list"/"find"/"get", optionally followed by "me", "the",
      "top" or "first", then a number (``"show me the top 10"``);
    * "names of" followed by a number (``"names of 3"``);
    * a number followed by "oldest", "newest", "biggest", "smallest" or
      "largest".

    Matching is case-insensitive.

    Args:
        nl_query: The natural-language question.

    Returns:
        The requested count as an ``int``, or ``None`` if no pattern matches.
    """
    patterns = [
        r'\b(\d+)\s+(?:ships?|vessels?|boats?|records?|results?|entries?|names?)\b',
        # Leading \b keeps the verb from matching inside longer words
        # (e.g. "budget 5" must not be read as "get 5").
        r'\b(?:show|list|find|get)\s+(?:me\s+)?(?:the\s+)?(?:top\s+|first\s+)?(\d+)',
        # Trailing \b (not a required space) so "names of 3" also matches at
        # the very end of the question.
        r'\bnames\s+of\s+(\d+)\b',
        r'\b(\d+)\s+(?:oldest|newest|biggest|smallest|largest)',
    ]
    for pattern in patterns:
        match = re.search(pattern, nl_query, re.IGNORECASE)
        if match:
            return int(match.group(1))
    return None
def clean_sql_output(sql_text, user_limit=None):
    """Extract a single SQL statement from raw model output.

    Lines starting with a Markdown code fence (e.g. "```sql") are skipped
    wherever they appear. Collection starts at the first line beginning with
    ``SELECT`` (case-insensitive). It continues over the following non-blank
    lines until one ends with ``;`` or a closing fence is reached; anything
    after that is ignored. Collected lines are joined with single spaces and
    trailing ``;`` are removed.

    Args:
        sql_text: Raw text returned by the model.
        user_limit: Optional row count to enforce. When truthy and a statement
            was found, a LIMIT at the very end of the statement (optionally
            followed by ``OFFSET n``) has its count replaced. Otherwise
            ``" LIMIT <user_limit>"`` is appended. LIMITs inside subqueries
            are left untouched.

    Returns:
        The cleaned SQL, or ``""`` if no line starting with SELECT was found.
    """
    collected = []
    for raw_line in sql_text.strip().split('\n'):
        line = raw_line.strip()
        if line.startswith("```"):
            if collected:
                # Closing fence: whatever follows is commentary, not SQL.
                break
            # Opening fence such as "```sql".
            continue
        if not line:
            continue
        if collected or line.upper().startswith('SELECT'):
            collected.append(line)
            if line.endswith(';'):
                break

    sql = " ".join(collected).rstrip(';').rstrip()

    if user_limit and sql:
        # Only a LIMIT at the end of the statement (optionally followed by
        # OFFSET n) is the query's own limit. One inside parentheses is
        # followed by ")" and therefore not matched.
        trailing_limit = re.compile(
            r'\bLIMIT\s+\d+(?=(?:\s+OFFSET\s+\d+)?\s*$)', re.IGNORECASE
        )
        replacement = f"LIMIT {user_limit}"
        sql, replaced = trailing_limit.subn(lambda _m: replacement, sql, count=1)
        if not replaced:
            sql += f" {replacement}"
    return sql
def text_to_sql(nl_query):
    """Translate a natural-language question into SQL and run it.

    Steps:
    1. Read the schema with ``get_schema()``.
    2. Build a prompt for the model and call ``query_huggingface_api``.
    3. Clean the reply with ``clean_sql_output``, enforcing any row count
       found by ``extract_user_requested_limit``.
    4. Check that the result is a single SELECT statement.
    5. Run it with ``execute_sql``.

    Args:
        nl_query: The user's question.

    Returns:
        ``(sql, results)`` on success, where ``results`` is whatever
        ``execute_sql`` returns. On any failure, returns
        ``("Error: <message>", [])``; this function does not raise.
    """
    try:
        schema = get_schema()
        user_limit = extract_user_requested_limit(nl_query)
        prompt = f"""### Task
Generate a PostgreSQL query to answer this question: {nl_query}
### Database Schema
{schema}
### Instructions
- Return only the SQL query
- Use PostgreSQL syntax
- Be precise with table and column names
### SQL Query:"""
        print("Querying Hugging Face Inference API...")
        generated_sql = query_huggingface_api(prompt)
        if not generated_sql:
            raise ValueError("No SQL generated from the model")

        sql = clean_sql_output(generated_sql, user_limit)
        if not sql or not sql.upper().startswith('SELECT'):
            raise ValueError(f"Invalid SQL generated: {sql}")
        # The model's output is steered by the user's question, so treat it as
        # untrusted. Cleaning removes only trailing semicolons, so a remaining
        # ';' means another statement hidden behind the leading SELECT
        # (e.g. "SELECT 1; DROP TABLE ..."). Refuse to execute it.
        # NOTE(review): this also rejects a ';' inside a string literal.
        if ';' in sql:
            raise ValueError(f"Refusing to run multiple SQL statements: {sql}")

        print(f"Generated SQL: {sql}")
        results = execute_sql(sql)
        return sql, results
    except Exception as e:
        error_msg = str(e)
        print(f"Error in text_to_sql: {error_msg}")
        return f"Error: {error_msg}", []