acadiaway committed on
Commit
4e15631
·
1 Parent(s): dab1d38

Use Qwen model via the Hugging Face Inference API

Browse files
Files changed (1) hide show
  1. pipeline.py +74 -30
pipeline.py CHANGED
@@ -4,14 +4,14 @@ import time
4
  import re
5
  from db_utils import get_schema, execute_sql
6
 
7
- # Hugging Face Inference API endpoint
8
- API_URL = "https://api-inference.huggingface.co/models/defog/sqlcoder-7b-2"
9
 
10
  def query_huggingface_api(prompt, max_retries=3):
11
  """Query the Hugging Face Inference API"""
12
  hf_token = os.getenv("HF_TOKEN")
13
  if not hf_token:
14
- raise ValueError("HF_TOKEN not found in environment variables. Add it to your Space secrets.")
15
 
16
  headers = {"Authorization": f"Bearer {hf_token}"}
17
  payload = {
@@ -20,37 +20,48 @@ def query_huggingface_api(prompt, max_retries=3):
20
  "max_new_tokens": 200,
21
  "temperature": 0.1,
22
  "do_sample": False,
23
- "return_full_text": False
 
24
  }
25
  }
26
 
27
  for attempt in range(max_retries):
28
  try:
 
29
  response = requests.post(API_URL, headers=headers, json=payload, timeout=30)
 
30
 
31
  if response.status_code == 200:
32
  result = response.json()
 
 
33
  if isinstance(result, list) and len(result) > 0:
34
- return result[0].get("generated_text", "").strip()
35
- return str(result).strip()
 
 
 
36
 
37
  elif response.status_code == 503:
38
  wait_time = 20 * (attempt + 1)
39
- print(f"Model loading, waiting {wait_time} seconds...")
40
  time.sleep(wait_time)
41
  continue
42
 
43
  else:
44
  error_msg = f"API Error {response.status_code}: {response.text}"
 
45
  if attempt == max_retries - 1:
46
  raise Exception(error_msg)
47
 
48
  except requests.exceptions.Timeout:
 
49
  if attempt == max_retries - 1:
50
  raise Exception("Request timed out after multiple attempts")
51
  time.sleep(5)
52
 
53
  except Exception as e:
 
54
  if attempt == max_retries - 1:
55
  raise e
56
  time.sleep(5)
@@ -80,16 +91,24 @@ def clean_sql_output(sql_text, user_limit=None):
80
  if sql_text.startswith("```"):
81
  lines = sql_text.split('\n')
82
  sql_text = '\n'.join(lines[1:-1]) if len(lines) > 2 else sql_text
83
-
84
- # Extract SQL
85
  lines = sql_text.split('\n')
86
  sql = ""
87
  for line in lines:
88
  line = line.strip()
89
  if line and (line.upper().startswith('SELECT') or sql):
90
  sql += line + " "
91
- if line.endswith(';'):
92
- break
 
 
 
 
 
 
 
 
93
 
94
  sql = sql.strip().rstrip(';')
95
 
@@ -101,40 +120,65 @@ def clean_sql_output(sql_text, user_limit=None):
101
  return sql
102
 
103
  def text_to_sql(nl_query):
104
- """Convert natural language to SQL using Hugging Face Inference API"""
105
  try:
106
- schema = get_schema()
 
 
 
 
 
 
 
 
 
 
107
  user_limit = extract_user_requested_limit(nl_query)
 
108
 
109
- prompt = f"""### Task
110
- Generate a PostgreSQL query to answer this question: {nl_query}
 
111
 
112
- ### Database Schema
113
- {schema}
114
 
115
- ### Instructions
116
- - Return only the SQL query
117
  - Use PostgreSQL syntax
118
  - Be precise with table and column names
 
 
 
 
 
 
 
119
 
120
- ### SQL Query:"""
121
-
122
- print("Querying Hugging Face Inference API...")
123
  generated_sql = query_huggingface_api(prompt)
 
124
 
125
  if not generated_sql:
126
- raise ValueError("No SQL generated from the model")
127
 
 
128
  sql = clean_sql_output(generated_sql, user_limit)
 
129
 
130
  if not sql or not sql.upper().startswith('SELECT'):
131
- raise ValueError(f"Invalid SQL generated: {sql}")
132
 
133
- print(f"Generated SQL: {sql}")
134
- results = execute_sql(sql)
135
- return sql, results
 
 
 
 
 
 
136
 
137
  except Exception as e:
138
- error_msg = str(e)
139
- print(f"Error in text_to_sql: {error_msg}")
140
- return f"Error: {error_msg}", []
 
4
  import re
5
  from db_utils import get_schema, execute_sql
6
 
7
+ # Hugging Face Inference API endpoint for Qwen2.5-Coder
8
+ API_URL = "https://api-inference.huggingface.co/models/Qwen/Qwen2.5-Coder-7B-Instruct"
9
 
10
  def query_huggingface_api(prompt, max_retries=3):
11
  """Query the Hugging Face Inference API"""
12
  hf_token = os.getenv("HF_TOKEN")
13
  if not hf_token:
14
+ raise ValueError("HF_TOKEN not found in environment variables")
15
 
16
  headers = {"Authorization": f"Bearer {hf_token}"}
17
  payload = {
 
20
  "max_new_tokens": 200,
21
  "temperature": 0.1,
22
  "do_sample": False,
23
+ "return_full_text": False,
24
+ "stop": ["###", "\n\n"]
25
  }
26
  }
27
 
28
  for attempt in range(max_retries):
29
  try:
30
+ print(f"=== DEBUG: API attempt {attempt + 1}")
31
  response = requests.post(API_URL, headers=headers, json=payload, timeout=30)
32
+ print(f"=== DEBUG: API Response Status: {response.status_code}")
33
 
34
  if response.status_code == 200:
35
  result = response.json()
36
+ print(f"=== DEBUG: API Response: {result}")
37
+
38
  if isinstance(result, list) and len(result) > 0:
39
+ generated_text = result[0].get("generated_text", "").strip()
40
+ else:
41
+ generated_text = str(result).strip()
42
+
43
+ return generated_text
44
 
45
  elif response.status_code == 503:
46
  wait_time = 20 * (attempt + 1)
47
+ print(f"=== DEBUG: Model loading, waiting {wait_time} seconds...")
48
  time.sleep(wait_time)
49
  continue
50
 
51
  else:
52
  error_msg = f"API Error {response.status_code}: {response.text}"
53
+ print(f"=== DEBUG: {error_msg}")
54
  if attempt == max_retries - 1:
55
  raise Exception(error_msg)
56
 
57
  except requests.exceptions.Timeout:
58
+ print(f"=== DEBUG: Timeout on attempt {attempt + 1}")
59
  if attempt == max_retries - 1:
60
  raise Exception("Request timed out after multiple attempts")
61
  time.sleep(5)
62
 
63
  except Exception as e:
64
+ print(f"=== DEBUG: Exception on attempt {attempt + 1}: {e}")
65
  if attempt == max_retries - 1:
66
  raise e
67
  time.sleep(5)
 
91
  if sql_text.startswith("```"):
92
  lines = sql_text.split('\n')
93
  sql_text = '\n'.join(lines[1:-1]) if len(lines) > 2 else sql_text
94
+
95
+ # Handle multiple lines - take the SQL part
96
  lines = sql_text.split('\n')
97
  sql = ""
98
  for line in lines:
99
  line = line.strip()
100
  if line and (line.upper().startswith('SELECT') or sql):
101
  sql += line + " "
102
+ if line.endswith(';'):
103
+ break
104
+
105
+ if not sql:
106
+ # If no SELECT found, take the first non-empty line
107
+ for line in lines:
108
+ line = line.strip()
109
+ if line:
110
+ sql = line
111
+ break
112
 
113
  sql = sql.strip().rstrip(';')
114
 
 
120
  return sql
121
 
122
# Maximum number of schema characters embedded in the prompt. The original
# code hard-coded this slice; presumably it bounds the prompt size -- confirm
# against the model's context budget before raising it.
_MAX_SCHEMA_CHARS = 1500

# ChatML control markers used to frame the prompt below. They must never
# appear inside user-supplied text, otherwise the user could close the "user"
# turn early and inject their own system/assistant turns.
_CHATML_MARKERS = ("<|im_start|>", "<|im_end|>")


def _strip_chat_markers(text):
    """Return *text* with every ChatML control marker removed."""
    for marker in _CHATML_MARKERS:
        text = text.replace(marker, "")
    return text


def _has_statement_separator(sql):
    """Return True if *sql* contains a ';' outside single-quoted literals.

    A doubled quote ('') inside a literal toggles the flag twice, so it is
    handled correctly. Semicolons inside double-quoted identifiers or
    dollar-quoted strings are (conservatively) treated as separators.
    """
    in_literal = False
    for ch in sql:
        if ch == "'":
            in_literal = not in_literal
        elif ch == ";" and not in_literal:
            return True
    return False


def text_to_sql(nl_query):
    """Convert natural language to SQL using Qwen2.5-Coder via HF Inference API.

    Args:
        nl_query: The user's question in natural language.

    Returns:
        A ``(sql, results)`` tuple on success, where ``results`` is whatever
        ``execute_sql`` returned. On any failure the function does not raise;
        it returns ``("Error: <reason>", [])`` instead.
    """
    try:
        print(f"=== DEBUG: Starting text_to_sql with query: {nl_query}")

        # Reject missing/blank input up front instead of spending an API call.
        if not isinstance(nl_query, str):
            return "Error: No question provided", []
        question = _strip_chat_markers(nl_query).strip()
        if not question:
            return "Error: No question provided", []

        # Get database schema
        try:
            schema = get_schema()
            print(f"=== DEBUG: Schema retrieved, length: {len(schema)}")
        except Exception as e:
            print(f"=== DEBUG: Schema error: {e}")
            return f"Error: Database schema access failed: {str(e)}", []

        # Extract user limit
        user_limit = extract_user_requested_limit(nl_query)
        print(f"=== DEBUG: Extracted user limit: {user_limit}")

        # ChatML-framed prompt for Qwen2.5-Coder. Only the sanitized question
        # is interpolated into the user turn.
        prompt = f"""<|im_start|>system
You are an expert SQL developer. Generate PostgreSQL queries based on natural language questions.

Database Schema:
{schema[:_MAX_SCHEMA_CHARS]}

Rules:
- Return ONLY the SQL query
- Use PostgreSQL syntax
- Be precise with table and column names
- Do not include explanations or markdown formatting
<|im_end|>
<|im_start|>user
{question}
<|im_end|>
<|im_start|>assistant
"""

        print("=== DEBUG: Calling Qwen2.5-Coder API...")
        generated_sql = query_huggingface_api(prompt)
        print(f"=== DEBUG: Generated SQL raw: {generated_sql}")

        if not generated_sql:
            return "Error: No SQL generated from the model", []

        # Clean the SQL output
        sql = clean_sql_output(generated_sql, user_limit)
        print(f"=== DEBUG: Final cleaned SQL: {sql}")

        if not sql or not sql.upper().startswith('SELECT'):
            return f"Error: Invalid SQL generated: {sql}", []

        # Model output is untrusted: a leading SELECT alone does not prevent
        # a second statement (e.g. "SELECT 1; DROP TABLE t") from being run.
        if _has_statement_separator(sql):
            return "Error: Multiple SQL statements are not allowed", []

        # Execute SQL
        print("=== DEBUG: Executing SQL...")
        try:
            results = execute_sql(sql)
            print(f"=== DEBUG: SQL executed successfully, {len(results)} results")
            return sql, results
        except Exception as e:
            print(f"=== DEBUG: SQL execution error: {e}")
            return f"Error: SQL execution failed: {str(e)}", []

    except Exception as e:
        # Last-resort boundary: callers rely on getting a tuple, never an exception.
        print(f"=== DEBUG: General error in text_to_sql: {e}")
        return f"Error: {str(e)}", []