Spaces:
Running
Running
use API Qwen model
Browse files- pipeline.py +74 -30
pipeline.py
CHANGED
@@ -4,14 +4,14 @@ import time
|
|
4 |
import re
|
5 |
from db_utils import get_schema, execute_sql
|
6 |
|
7 |
-
# Hugging Face Inference API endpoint
|
8 |
-
API_URL = "https://api-inference.huggingface.co/models/
|
9 |
|
10 |
def query_huggingface_api(prompt, max_retries=3):
|
11 |
"""Query the Hugging Face Inference API"""
|
12 |
hf_token = os.getenv("HF_TOKEN")
|
13 |
if not hf_token:
|
14 |
-
raise ValueError("HF_TOKEN not found in environment variables
|
15 |
|
16 |
headers = {"Authorization": f"Bearer {hf_token}"}
|
17 |
payload = {
|
@@ -20,37 +20,48 @@ def query_huggingface_api(prompt, max_retries=3):
|
|
20 |
"max_new_tokens": 200,
|
21 |
"temperature": 0.1,
|
22 |
"do_sample": False,
|
23 |
-
"return_full_text": False
|
|
|
24 |
}
|
25 |
}
|
26 |
|
27 |
for attempt in range(max_retries):
|
28 |
try:
|
|
|
29 |
response = requests.post(API_URL, headers=headers, json=payload, timeout=30)
|
|
|
30 |
|
31 |
if response.status_code == 200:
|
32 |
result = response.json()
|
|
|
|
|
33 |
if isinstance(result, list) and len(result) > 0:
|
34 |
-
|
35 |
-
|
|
|
|
|
|
|
36 |
|
37 |
elif response.status_code == 503:
|
38 |
wait_time = 20 * (attempt + 1)
|
39 |
-
print(f"Model loading, waiting {wait_time} seconds...")
|
40 |
time.sleep(wait_time)
|
41 |
continue
|
42 |
|
43 |
else:
|
44 |
error_msg = f"API Error {response.status_code}: {response.text}"
|
|
|
45 |
if attempt == max_retries - 1:
|
46 |
raise Exception(error_msg)
|
47 |
|
48 |
except requests.exceptions.Timeout:
|
|
|
49 |
if attempt == max_retries - 1:
|
50 |
raise Exception("Request timed out after multiple attempts")
|
51 |
time.sleep(5)
|
52 |
|
53 |
except Exception as e:
|
|
|
54 |
if attempt == max_retries - 1:
|
55 |
raise e
|
56 |
time.sleep(5)
|
@@ -80,16 +91,24 @@ def clean_sql_output(sql_text, user_limit=None):
|
|
80 |
if sql_text.startswith("```"):
|
81 |
lines = sql_text.split('\n')
|
82 |
sql_text = '\n'.join(lines[1:-1]) if len(lines) > 2 else sql_text
|
83 |
-
|
84 |
-
#
|
85 |
lines = sql_text.split('\n')
|
86 |
sql = ""
|
87 |
for line in lines:
|
88 |
line = line.strip()
|
89 |
if line and (line.upper().startswith('SELECT') or sql):
|
90 |
sql += line + " "
|
91 |
-
|
92 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
93 |
|
94 |
sql = sql.strip().rstrip(';')
|
95 |
|
@@ -101,40 +120,65 @@ def clean_sql_output(sql_text, user_limit=None):
|
|
101 |
return sql
|
102 |
|
103 |
def text_to_sql(nl_query):
|
104 |
-
"""Convert natural language to SQL using
|
105 |
try:
|
106 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
user_limit = extract_user_requested_limit(nl_query)
|
|
|
108 |
|
109 |
-
prompt
|
110 |
-
|
|
|
111 |
|
112 |
-
|
113 |
-
{schema}
|
114 |
|
115 |
-
|
116 |
-
- Return
|
117 |
- Use PostgreSQL syntax
|
118 |
- Be precise with table and column names
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
119 |
|
120 |
-
|
121 |
-
|
122 |
-
print("Querying Hugging Face Inference API...")
|
123 |
generated_sql = query_huggingface_api(prompt)
|
|
|
124 |
|
125 |
if not generated_sql:
|
126 |
-
|
127 |
|
|
|
128 |
sql = clean_sql_output(generated_sql, user_limit)
|
|
|
129 |
|
130 |
if not sql or not sql.upper().startswith('SELECT'):
|
131 |
-
|
132 |
|
133 |
-
|
134 |
-
|
135 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
136 |
|
137 |
except Exception as e:
|
138 |
-
|
139 |
-
|
140 |
-
return f"Error: {error_msg}", []
|
|
|
4 |
import re
|
5 |
from db_utils import get_schema, execute_sql
|
6 |
|
7 |
+
# Hugging Face Inference API endpoint for Qwen2.5-Coder
|
8 |
+
API_URL = "https://api-inference.huggingface.co/models/Qwen/Qwen2.5-Coder-7B-Instruct"
|
9 |
|
10 |
def query_huggingface_api(prompt, max_retries=3):
|
11 |
"""Query the Hugging Face Inference API"""
|
12 |
hf_token = os.getenv("HF_TOKEN")
|
13 |
if not hf_token:
|
14 |
+
raise ValueError("HF_TOKEN not found in environment variables")
|
15 |
|
16 |
headers = {"Authorization": f"Bearer {hf_token}"}
|
17 |
payload = {
|
|
|
20 |
"max_new_tokens": 200,
|
21 |
"temperature": 0.1,
|
22 |
"do_sample": False,
|
23 |
+
"return_full_text": False,
|
24 |
+
"stop": ["###", "\n\n"]
|
25 |
}
|
26 |
}
|
27 |
|
28 |
for attempt in range(max_retries):
|
29 |
try:
|
30 |
+
print(f"=== DEBUG: API attempt {attempt + 1}")
|
31 |
response = requests.post(API_URL, headers=headers, json=payload, timeout=30)
|
32 |
+
print(f"=== DEBUG: API Response Status: {response.status_code}")
|
33 |
|
34 |
if response.status_code == 200:
|
35 |
result = response.json()
|
36 |
+
print(f"=== DEBUG: API Response: {result}")
|
37 |
+
|
38 |
if isinstance(result, list) and len(result) > 0:
|
39 |
+
generated_text = result[0].get("generated_text", "").strip()
|
40 |
+
else:
|
41 |
+
generated_text = str(result).strip()
|
42 |
+
|
43 |
+
return generated_text
|
44 |
|
45 |
elif response.status_code == 503:
|
46 |
wait_time = 20 * (attempt + 1)
|
47 |
+
print(f"=== DEBUG: Model loading, waiting {wait_time} seconds...")
|
48 |
time.sleep(wait_time)
|
49 |
continue
|
50 |
|
51 |
else:
|
52 |
error_msg = f"API Error {response.status_code}: {response.text}"
|
53 |
+
print(f"=== DEBUG: {error_msg}")
|
54 |
if attempt == max_retries - 1:
|
55 |
raise Exception(error_msg)
|
56 |
|
57 |
except requests.exceptions.Timeout:
|
58 |
+
print(f"=== DEBUG: Timeout on attempt {attempt + 1}")
|
59 |
if attempt == max_retries - 1:
|
60 |
raise Exception("Request timed out after multiple attempts")
|
61 |
time.sleep(5)
|
62 |
|
63 |
except Exception as e:
|
64 |
+
print(f"=== DEBUG: Exception on attempt {attempt + 1}: {e}")
|
65 |
if attempt == max_retries - 1:
|
66 |
raise e
|
67 |
time.sleep(5)
|
|
|
91 |
if sql_text.startswith("```"):
|
92 |
lines = sql_text.split('\n')
|
93 |
sql_text = '\n'.join(lines[1:-1]) if len(lines) > 2 else sql_text
|
94 |
+
|
95 |
+
# Handle multiple lines - take the SQL part
|
96 |
lines = sql_text.split('\n')
|
97 |
sql = ""
|
98 |
for line in lines:
|
99 |
line = line.strip()
|
100 |
if line and (line.upper().startswith('SELECT') or sql):
|
101 |
sql += line + " "
|
102 |
+
if line.endswith(';'):
|
103 |
+
break
|
104 |
+
|
105 |
+
if not sql:
|
106 |
+
# If no SELECT found, take the first non-empty line
|
107 |
+
for line in lines:
|
108 |
+
line = line.strip()
|
109 |
+
if line:
|
110 |
+
sql = line
|
111 |
+
break
|
112 |
|
113 |
sql = sql.strip().rstrip(';')
|
114 |
|
|
|
120 |
return sql
|
121 |
|
122 |
def text_to_sql(nl_query):
|
123 |
+
"""Convert natural language to SQL using Qwen2.5-Coder via HF Inference API"""
|
124 |
try:
|
125 |
+
print(f"=== DEBUG: Starting text_to_sql with query: {nl_query}")
|
126 |
+
|
127 |
+
# Get database schema
|
128 |
+
try:
|
129 |
+
schema = get_schema()
|
130 |
+
print(f"=== DEBUG: Schema retrieved, length: {len(schema)}")
|
131 |
+
except Exception as e:
|
132 |
+
print(f"=== DEBUG: Schema error: {e}")
|
133 |
+
return f"Error: Database schema access failed: {str(e)}", []
|
134 |
+
|
135 |
+
# Extract user limit
|
136 |
user_limit = extract_user_requested_limit(nl_query)
|
137 |
+
print(f"=== DEBUG: Extracted user limit: {user_limit}")
|
138 |
|
139 |
+
# Create optimized prompt for Qwen2.5-Coder
|
140 |
+
prompt = f"""<|im_start|>system
|
141 |
+
You are an expert SQL developer. Generate PostgreSQL queries based on natural language questions.
|
142 |
|
143 |
+
Database Schema:
|
144 |
+
{schema[:1500]}
|
145 |
|
146 |
+
Rules:
|
147 |
+
- Return ONLY the SQL query
|
148 |
- Use PostgreSQL syntax
|
149 |
- Be precise with table and column names
|
150 |
+
- Do not include explanations or markdown formatting
|
151 |
+
<|im_end|>
|
152 |
+
<|im_start|>user
|
153 |
+
{nl_query}
|
154 |
+
<|im_end|>
|
155 |
+
<|im_start|>assistant
|
156 |
+
"""
|
157 |
|
158 |
+
print("=== DEBUG: Calling Qwen2.5-Coder API...")
|
|
|
|
|
159 |
generated_sql = query_huggingface_api(prompt)
|
160 |
+
print(f"=== DEBUG: Generated SQL raw: {generated_sql}")
|
161 |
|
162 |
if not generated_sql:
|
163 |
+
return "Error: No SQL generated from the model", []
|
164 |
|
165 |
+
# Clean the SQL output
|
166 |
sql = clean_sql_output(generated_sql, user_limit)
|
167 |
+
print(f"=== DEBUG: Final cleaned SQL: {sql}")
|
168 |
|
169 |
if not sql or not sql.upper().startswith('SELECT'):
|
170 |
+
return f"Error: Invalid SQL generated: {sql}", []
|
171 |
|
172 |
+
# Execute SQL
|
173 |
+
print("=== DEBUG: Executing SQL...")
|
174 |
+
try:
|
175 |
+
results = execute_sql(sql)
|
176 |
+
print(f"=== DEBUG: SQL executed successfully, {len(results)} results")
|
177 |
+
return sql, results
|
178 |
+
except Exception as e:
|
179 |
+
print(f"=== DEBUG: SQL execution error: {e}")
|
180 |
+
return f"Error: SQL execution failed: {str(e)}", []
|
181 |
|
182 |
except Exception as e:
|
183 |
+
print(f"=== DEBUG: General error in text_to_sql: {e}")
|
184 |
+
return f"Error: {str(e)}", []
|
|