"""Translate natural-language questions into PostgreSQL queries.

The SQL is generated by a model served through the Hugging Face Inference
API, sanitised down to a single SELECT statement, and then executed through
``db_utils.execute_sql``.
"""

import os
import re
import time

import requests

from db_utils import get_schema, execute_sql

# Hugging Face Inference API endpoint
API_URL = "https://api-inference.huggingface.co/models/defog/sqlcoder-7b-2"

# Pause between attempts after a timeout or a retryable error.
_RETRY_DELAY_SECONDS = 5

# 4xx statuses that are still worth retrying (request timeout, rate limiting).
_RETRYABLE_CLIENT_STATUSES = frozenset({408, 429})

# Patterns that capture an explicit row count in the user's question,
# tried in order; the first match wins.
_LIMIT_PATTERNS = tuple(
    re.compile(pattern, re.IGNORECASE)
    for pattern in (
        r'\b(\d+)\s+(?:ships?|vessels?|boats?|records?|results?|entries?|names?)\b',
        # Leading \b so that words merely ending in a verb ("widget 5",
        # "forget 10") are not read as "get 5" / "get 10".
        r'\b(?:show|list|find|get)\s+(?:me\s+)?(?:the\s+)?(?:top\s+|first\s+)?(\d+)',
        r'(?:names\s+of\s+)(\d+)\s+',
        r'\b(\d+)\s+(?:oldest|newest|biggest|smallest|largest)',
    )
)

# A markdown code fence, optionally followed by a SQL language tag.
_FENCE_RE = re.compile(r"```(?:(?:postgresql|postgres|pgsql|sql)\b)?", re.IGNORECASE)

# A line that begins (after indentation) with SELECT.
_SELECT_LINE_RE = re.compile(r"^\s*SELECT", re.IGNORECASE | re.MULTILINE)

# Everything up to the first semicolon that is not inside a quoted
# string ('...') or a quoted identifier ("...").
_FIRST_STATEMENT_RE = re.compile(r"""(?:[^;'"]+|'[^']*'|"[^"]*")*""")

# A LIMIT clause at the very end of the statement, optionally followed by
# OFFSET. Anchoring at the end leaves LIMITs inside subqueries untouched.
_TRAILING_LIMIT_RE = re.compile(
    r"\s+LIMIT\s+\d+(?=(?:\s+OFFSET\s+\d+)?\s*$)", re.IGNORECASE
)


class _FatalAPIError(Exception):
    """HTTP error that will not be fixed by retrying (e.g. 401, 400, 404)."""


def _extract_generated_text(result):
    """Pull the generated text out of a decoded Inference API response.

    A non-empty list whose first element is a dict yields that dict's
    ``generated_text`` (empty string when missing or null). Any other shape
    is returned as its stripped string representation.
    """
    if isinstance(result, list) and result:
        first = result[0]
        if isinstance(first, dict):
            return str(first.get("generated_text") or "").strip()
    return str(result).strip()


def query_huggingface_api(prompt, max_retries=3):
    """Query the Hugging Face Inference API and return the generated text.

    Args:
        prompt: Full prompt text sent as the model input.
        max_retries: Maximum number of HTTP attempts.

    Returns:
        The generated text with surrounding whitespace stripped.

    Raises:
        ValueError: If the ``HF_TOKEN`` environment variable is not set.
        Exception: If the API answers with a non-retryable client error, or
            if every attempt fails (timeout, HTTP error, model still loading).
    """
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        raise ValueError("HF_TOKEN not found in environment variables. Add it to your Space secrets.")

    headers = {"Authorization": f"Bearer {hf_token}"}
    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": 200,
            "temperature": 0.1,
            "do_sample": False,
            "return_full_text": False,
        },
    }

    for attempt in range(max_retries):
        is_last_attempt = attempt == max_retries - 1
        try:
            response = requests.post(API_URL, headers=headers, json=payload, timeout=30)
            status = response.status_code

            if status == 200:
                return _extract_generated_text(response.json())

            if status == 503:
                # 503 is treated as "model still loading": back off for
                # longer on each attempt, but do not sleep when there is
                # no attempt left to benefit from the wait.
                if not is_last_attempt:
                    wait_time = 20 * (attempt + 1)
                    print(f"Model loading, waiting {wait_time} seconds...")
                    time.sleep(wait_time)
                continue

            error_msg = f"API Error {status}: {response.text}"
            if 400 <= status < 500 and status not in _RETRYABLE_CLIENT_STATUSES:
                # Bad token / bad request: repeating the call cannot help.
                raise _FatalAPIError(error_msg)
            raise Exception(error_msg)
        except _FatalAPIError:
            raise
        except requests.exceptions.Timeout as err:
            if is_last_attempt:
                raise Exception("Request timed out after multiple attempts") from err
            time.sleep(_RETRY_DELAY_SECONDS)
        except Exception:
            if is_last_attempt:
                raise
            time.sleep(_RETRY_DELAY_SECONDS)

    raise Exception("Failed to get response after all retries")


def extract_user_requested_limit(nl_query):
    """Extract a user-requested row count from a natural-language query.

    Args:
        nl_query: The user's question, e.g. ``"show me the top 5 ships"``.

    Returns:
        The requested count as an ``int``, or ``None`` when the question
        does not mention one.
    """
    for pattern in _LIMIT_PATTERNS:
        match = pattern.search(nl_query)
        if match:
            return int(match.group(1))
    return None


def _select_fenced_chunk(text):
    """Return the part of ``text`` that holds the query, without code fences.

    The text is split on markdown fences and the first piece containing a
    line that starts with SELECT is returned. This copes with a missing
    closing fence, prose before or after the fenced block, and a fenced
    query written on a single line. Text without fences is returned as is.
    """
    for chunk in _FENCE_RE.split(text):
        if _SELECT_LINE_RE.search(chunk):
            return chunk
    return text


def clean_sql_output(sql_text, user_limit=None):
    """Clean model output down to a single SELECT statement.

    Args:
        sql_text: Raw text generated by the model; may contain markdown
            fences, leading prose and a trailing semicolon.
        user_limit: Optional row count requested by the user. When truthy it
            replaces the statement's trailing LIMIT clause (or is appended).

    Returns:
        The statement on one line without a trailing semicolon, or an empty
        string when no line starting with SELECT was found.

    Raises:
        ValueError: If ``user_limit`` is truthy but not convertible to int.
    """
    if not sql_text:
        return ""

    body = _select_fenced_chunk(sql_text.strip())

    # Collect from the first line that starts with SELECT up to and
    # including the first line that ends the statement.
    collected = []
    for raw_line in body.split('\n'):
        line = raw_line.strip()
        if not line:
            continue
        if collected or line.upper().startswith('SELECT'):
            collected.append(line)
            if line.endswith(';'):
                break

    # Keep only the first statement: the result is executed against the
    # database, so anything after a top-level semicolon
    # ("SELECT 1; DROP TABLE ...") must not survive.
    sql = _FIRST_STATEMENT_RE.match(" ".join(collected)).group(0).strip()

    # Apply user-requested limit
    if sql and user_limit:
        sql = _TRAILING_LIMIT_RE.sub('', sql)
        # int() guarantees only a number is ever interpolated into the SQL.
        sql += f" LIMIT {int(user_limit)}"

    return sql


def text_to_sql(nl_query):
    """Convert natural language to SQL using the Hugging Face Inference API.

    Builds a prompt from the database schema and the question, asks the
    model for a query, sanitises it, and executes it.

    Args:
        nl_query: The user's question in natural language.

    Returns:
        A ``(sql, results)`` tuple. On any failure the first element is a
        string of the form ``"Error: <message>"`` and the second is an
        empty list; no exception is propagated to the caller.
    """
    try:
        schema = get_schema()
        user_limit = extract_user_requested_limit(nl_query)

        # NOTE(review): the prompt's line breaks were reconstructed from a
        # whitespace-collapsed source - confirm against the deployed prompt.
        prompt = f"""### Task
Generate a PostgreSQL query to answer this question: {nl_query}

### Database Schema
{schema}

### Instructions
- Return only the SQL query
- Use PostgreSQL syntax
- Be precise with table and column names

### SQL Query:"""

        print("Querying Hugging Face Inference API...")
        generated_sql = query_huggingface_api(prompt)

        if not generated_sql:
            raise ValueError("No SQL generated from the model")

        sql = clean_sql_output(generated_sql, user_limit)

        # Only read-only SELECT statements are ever passed to the database.
        if not sql or not sql.upper().startswith('SELECT'):
            raise ValueError(f"Invalid SQL generated: {sql}")

        print(f"Generated SQL: {sql}")
        results = execute_sql(sql)
        return sql, results

    except Exception as e:
        # Top-level boundary: report the failure through the return value.
        error_msg = str(e)
        print(f"Error in text_to_sql: {error_msg}")
        return f"Error: {error_msg}", []