Roxanne-WANG committed
Commit dedd8a5 · 1 Parent(s): c10b742
Files changed (4)
  1. app.py +1 -2
  2. build_whoosh_index.py +1 -3
  3. schema_item_filter.py +1 -1
  4. text2sql.py +147 -222
app.py CHANGED
@@ -26,7 +26,6 @@ db_schemas = {
      "Citizenship" text,
      PRIMARY KEY ("Singer_ID")
  );
-
  CREATE TABLE "song" (
      "Song_ID" int,
      "Title" text,
@@ -64,4 +63,4 @@ if question:
      # Get the model's response (in this case, SQL query or logits)
      response = text2sql_bot.get_response(question, db_id)
      st.write(f"**Database:** {db_id}")
-     st.write(f"**Results:** {response}")
+     st.write(f"**Results:** {response}")
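These hunks show only fragments of app.py. As a rough sketch of how the surrounding Streamlit flow presumably fits together (the widget calls and the ChatBot import path are assumptions for illustration; only the lines that appear in the diff come from the commit):

import streamlit as st
from text2sql import ChatBot

text2sql_bot = ChatBot()  # builds the model, schemas, DDLs, and Whoosh searchers once

db_id = st.selectbox("Choose a database", text2sql_bot.db_ids)    # assumed widget
question = st.text_input("Ask a question about this database")   # assumed widget

if question:
    # Get the model's response (in this case, SQL query or logits)
    response = text2sql_bot.get_response(question, db_id)
    st.write(f"**Database:** {db_id}")
    st.write(f"**Results:** {response}")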
build_whoosh_index.py CHANGED
@@ -8,7 +8,6 @@ def extract_contents_from_db(db_path, max_len=25):
      """
      Extract all non-null, unique text values of length <= max_len
      from every table and column in the SQLite database.
-
      Returns:
          List of tuples [(doc_id, text), ...]
      """
@@ -45,7 +44,6 @@ def extract_contents_from_db(db_path, max_len=25):
  def build_index_for_db(db_id, db_path, index_root="db_contents_index"):
      """
      Build (or open) a Whoosh index for a single database.
-
      - If the index already exists in index_root/db_id, it will be opened.
      - Otherwise, a new index is created and populated from the SQLite file.
      """
@@ -89,4 +87,4 @@ if __name__ == "__main__":
          print(f"Building Whoosh index for {db_id}...")
          build_index_for_db(db_id, db_file, INDEX_ROOT)

-     print("All indexes built successfully.")
+     print("All indexes built successfully.")
schema_item_filter.py CHANGED
@@ -347,4 +347,4 @@ if __name__ == "__main__":
      import json
      dataset = json.load(open("./data/sft_eval_{}_text2sql.json".format(dataset_name)))

-     sic.evaluate_coverage(dataset)
+     sic.evaluate_coverage(dataset)
text2sql.py CHANGED
@@ -1,255 +1,180 @@
- # Attribution: Original code by Ruoxin Wang
- # Repository: <your-repo-url>
-
- """
- Module: refactored_chatbot
- This module provides utilities for loading database schemas, extracting DDL,
- indexing content, and a ChatBot class to generate SQL queries from natural language.
- """
  import os
  import json
+ import torch
+ import copy
  import re
+ import sqlparse
  import sqlite3
- import copy
- from tqdm import tqdm

- import torch
+ from tqdm import tqdm
+ from utils.db_utils import get_db_schema
  from transformers import AutoModelForCausalLM, AutoTokenizer
  from whoosh import index
- import sqlparse
-
- from utils.db_utils import (
-     get_db_schema,
-     check_sql_executability,
-     get_matched_contents,
-     get_db_schema_sequence,
-     get_matched_content_sequence
- )
+ from whoosh.index import create_in
+ from whoosh.fields import Schema, TEXT
+ from whoosh.qparser import QueryParser
+ from utils.db_utils import check_sql_executability, get_matched_contents, get_db_schema_sequence, get_matched_content_sequence
  from schema_item_filter import SchemaItemClassifierInference, filter_schema

-
- class DatabaseUtils:
-     """
-     Utilities for loading database comments, schemas, and DDL statements.
-     """
-
-     @staticmethod
-     def _remove_similar_comments(names, comments):
-         """
-         Remove comments identical to table/column names (ignoring underscores and spaces).
-         """
-         filtered = []
-         for name, comment in zip(names, comments):
-             normalized_name = name.replace("_", "").replace(" ", "").lower()
-             normalized_comment = comment.replace("_", "").replace(" ", "").lower()
-             filtered.append("") if normalized_name == normalized_comment else filtered.append(comment)
-         return filtered
-
-     @staticmethod
-     def load_db_comments(table_json_path):
-         """
-         Load additional comments for tables and columns from a JSON file.
-
-         Args:
-             table_json_path (str): Path to JSON file containing table and column comments.
-
-         Returns:
-             dict: Mapping from database ID to comments structure.
-         """
-         additional_info = json.load(open(table_json_path))
-         db_comments = {}
-
-         for db_info in additional_info:
-             db_id = db_info["db_id"]
-             comment_dict = {}
-
-             # Process column comments
-             original_cols = db_info["column_names_original"]
-             col_names = [col.lower() for _, col in original_cols]
-             col_comments = [c.lower() for _, c in db_info["column_names"]]
-             col_comments = DatabaseUtils._remove_similar_comments(col_names, col_comments)
-             col_table_idxs = [t_idx for t_idx, _ in original_cols]
-
-             # Process table comments
-             original_tables = db_info["table_names_original"]
-             tbl_names = [tbl.lower() for tbl in original_tables]
-             tbl_comments = [c.lower() for c in db_info["table_names"]]
-             tbl_comments = DatabaseUtils._remove_similar_comments(tbl_names, tbl_comments)
-
-             for idx, name in enumerate(tbl_names):
-                 comment_dict[name] = {
-                     "table_comment": tbl_comments[idx],
-                     "column_comments": {}
-                 }
-                 # Associate columns
-                 for t_idx, col_name, col_comment in zip(col_table_idxs, col_names, col_comments):
-                     if t_idx == idx:
-                         comment_dict[name]["column_comments"][col_name] = col_comment
-
-             db_comments[db_id] = comment_dict
-
-         return db_comments
-
-     @staticmethod
-     def get_db_schemas(db_path, tables_json):
-         """
-         Build a mapping from database ID to its schema representation.
-
-         Args:
-             db_path (str): Directory containing database subdirectories.
-             tables_json (str): Path to JSON with table comments.
-
-         Returns:
-             dict: Mapping from db_id to schema object.
-         """
-         comments = DatabaseUtils.load_db_comments(tables_json)
-         schemas = {}
-         for db_id in tqdm(os.listdir(db_path), desc="Loading schemas"):
-             sqlite_path = os.path.join(db_path, db_id, f"{db_id}.sqlite")
-             schemas[db_id] = get_db_schema(sqlite_path, comments, db_id)
-         return schemas
-
-     @staticmethod
-     def get_db_ddls(db_path):
-         """
-         Extract formatted DDL statements for all tables in each database.
-
-         Args:
-             db_path (str): Directory containing database subdirectories.
-
-         Returns:
-             dict: Mapping from db_id to its DDL string.
-         """
-         ddls = {}
-         for db_id in os.listdir(db_path):
-             conn = sqlite3.connect(os.path.join(db_path, db_id, f"{db_id}.sqlite"))
-             cursor = conn.cursor()
-             cursor.execute("SELECT name, sql FROM sqlite_master WHERE type='table';")
-             ddl_statements = []
-
-             for name, raw_sql in cursor.fetchall():
-                 sql = raw_sql or ""
-                 sql = re.sub(r'--.*', '', sql).replace("\t", " ")
-                 sql = re.sub(r" +", " ", sql)
-                 formatted = sqlparse.format(
-                     sql,
-                     keyword_case="upper",
-                     identifier_case="lower",
-                     reindent_aligned=True
-                 )
-                 # Adjust spacing for readability
-                 formatted = formatted.replace(", ", ",\n ")
-                 if formatted.rstrip().endswith(";"):
-                     formatted = formatted.rstrip()[:-1] + "\n);"
-                 formatted = re.sub(r"(CREATE TABLE.*?)\(", r"\1(\n ", formatted)
-                 ddl_statements.append(formatted)
-
-             ddls[db_id] = "\n\n".join(ddl_statements)
-         return ddls
-
-
- class ChatBot:
-     """
-     ChatBot for generating and executing SQL queries using a causal language model.
-     """
-
-     def __init__(self, model_name: str = "seeklhy/codes-1b", device: str = "cuda:0") -> None:
-         """
-         Initialize the ChatBot with model and tokenizer.
-
-         Args:
-             model_name (str): HuggingFace model identifier.
-             device (str): CUDA device string or 'cpu'.
-         """
-         os.environ["CUDA_VISIBLE_DEVICES"] = device
+ def remove_similar_comments(names, comments):
+     '''
+     Remove table (or column) comments that have a high degree of similarity with their names
+     '''
+     new_comments = []
+     for name, comment in zip(names, comments):
+         if name.replace("_", "").replace(" ", "") == comment.replace("_", "").replace(" ", ""):
+             new_comments.append("")
+         else:
+             new_comments.append(comment)
+
+     return new_comments
+
+ def load_db_comments(table_json_path):
+     additional_db_info = json.load(open(table_json_path))
+     db_comments = dict()
+     for db_info in additional_db_info:
+         comment_dict = dict()
+
+         column_names = [column_name.lower() for _, column_name in db_info["column_names_original"]]
+         table_idx_of_each_column = [t_idx for t_idx, _ in db_info["column_names_original"]]
+         column_comments = [column_comment.lower() for _, column_comment in db_info["column_names"]]
+
+         assert len(column_names) == len(column_comments)
+         column_comments = remove_similar_comments(column_names, column_comments)
+
+         table_names = [table_name.lower() for table_name in db_info["table_names_original"]]
+         table_comments = [table_comment.lower() for table_comment in db_info["table_names"]]
+
+         assert len(table_names) == len(table_comments)
+         table_comments = remove_similar_comments(table_names, table_comments)
+
+         for table_idx, (table_name, table_comment) in enumerate(zip(table_names, table_comments)):
+             comment_dict[table_name] = {
+                 "table_comment": table_comment,
+                 "column_comments": dict()
+             }
+             for t_idx, column_name, column_comment in zip(table_idx_of_each_column, column_names, column_comments):
+                 if t_idx == table_idx:
+                     comment_dict[table_name]["column_comments"][column_name] = column_comment
+
+         db_comments[db_info["db_id"]] = comment_dict
+
+     return db_comments
+
+ def get_db_id2schema(db_path, tables_json):
+     db_comments = load_db_comments(tables_json)
+     db_id2schema = dict()
+
+     for db_id in tqdm(os.listdir(db_path)):
+         db_id2schema[db_id] = get_db_schema(os.path.join(db_path, db_id, db_id + ".sqlite"), db_comments, db_id)
+
+     return db_id2schema
+
+ def get_db_id2ddl(db_path):
+     db_ids = os.listdir(db_path)
+     db_id2ddl = dict()
+
+     for db_id in db_ids:
+         conn = sqlite3.connect(os.path.join(db_path, db_id, db_id + ".sqlite"))
+         cursor = conn.cursor()
+         cursor.execute("SELECT name, sql FROM sqlite_master WHERE type='table';")
+         tables = cursor.fetchall()
+         ddl = []
+
+         for table in tables:
+             table_name = table[0]
+             table_ddl = table[1]
+             table_ddl.replace("\t", " ")
+             while "  " in table_ddl:
+                 table_ddl = table_ddl.replace("  ", " ")
+
+             table_ddl = re.sub(r'--.*', '', table_ddl)
+             table_ddl = sqlparse.format(table_ddl, keyword_case = "upper", identifier_case = "lower", reindent_aligned = True)
+             table_ddl = table_ddl.replace(", ", ",\n ")
+
+             if table_ddl.endswith(";"):
+                 table_ddl = table_ddl[:-1]
+             table_ddl = table_ddl[:-1] + "\n);"
+             table_ddl = re.sub(r"(CREATE TABLE.*?)\(", r"\1(\n ", table_ddl)
+
+             ddl.append(table_ddl)
+         db_id2ddl[db_id] = "\n\n".join(ddl)
+
+     return db_id2ddl
+
+ class ChatBot():
+     def __init__(self) -> None:
+         os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+         model_name = "seeklhy/codes-1b"
          self.tokenizer = AutoTokenizer.from_pretrained(model_name)
-         self.model = AutoModelForCausalLM.from_pretrained(
-             model_name,
-             device_map="auto",
-             torch_dtype=torch.float16
-         )
+         self.model = AutoModelForCausalLM.from_pretrained(model_name, device_map = "auto", torch_dtype = torch.float16)
          self.max_length = 4096
          self.max_new_tokens = 256
          self.max_prefix_length = self.max_length - self.max_new_tokens

-         # Schema item classifier
-         self.schema_classifier = SchemaItemClassifierInference("Roxanne-WANG/LangSQL")
-
-         # Initialize content searchers
-         self.content_searchers = {}
-         index_dir = "db_contents_index"
-         for db_id in os.listdir(index_dir):
-             path = os.path.join(index_dir, db_id)
-             if index.exists_in(path):
-                 self.content_searchers[db_id] = index.open_dir(path).searcher()
+         # Directly loading the model from Hugging Face
+         self.sic = SchemaItemClassifierInference("Roxanne-WANG/LangSQL")
+         self.db_id2content_searcher = dict()
+         for db_id in os.listdir("db_contents_index"):
+             index_dir = os.path.join("db_contents_index", db_id)
+
+             # Open existing Whoosh index directory
+             if index.exists_in(index_dir):
+                 ix = index.open_dir(index_dir)
+                 # keep a searcher around for querying
+                 self.db_id2content_searcher[db_id] = ix.searcher()
              else:
-                 raise FileNotFoundError(f"Whoosh index not found for '{db_id}' at '{path}'")
+                 raise ValueError(f"No Whoosh index found for '{db_id}' at '{index_dir}'")

-         # Load schemas and DDLs
          self.db_ids = sorted(os.listdir("databases"))
-         self.schemas = DatabaseUtils.get_db_schemas("databases", "data/tables.json")
-         self.ddls = DatabaseUtils.get_db_ddls("databases")
-
-     def get_response(self, question: str, db_id: str) -> str:
-         """
-         Generate an executable SQL query for a natural language question.
-
-         Args:
-             question (str): User question in natural language.
-             db_id (str): Identifier of the target database.
-
-         Returns:
-             str: Executable SQL query or an error message.
-         """
-         # Prepare data
-         schema = copy.deepcopy(self.schemas[db_id])
-         contents = get_matched_contents(question, self.content_searchers[db_id])
+         self.db_id2schema = get_db_id2schema("databases", "data/tables.json")
+         self.db_id2ddl = get_db_id2ddl("databases")
+
+     def get_response(self, question, db_id):
          data = {
              "text": question,
-             "schema": schema,
-             "matched_contents": contents
+             "schema": copy.deepcopy(self.db_id2schema[db_id]),
+             "matched_contents": get_matched_contents(question, self.db_id2content_searcher[db_id])
          }
-         data = filter_schema(data, self.schema_classifier, top_k=6, top_m=10)
+         data = filter_schema(data, self.sic, 6, 10)
          data["schema_sequence"] = get_db_schema_sequence(data["schema"])
          data["content_sequence"] = get_matched_content_sequence(data["matched_contents"])
-
-         prefix = (
-             f"{data['schema_sequence']}\n"
-             f"{data['content_sequence']}\n"
-             f"{question}\n"
-         )
-
-         # Tokenize and ensure length limits
-         input_ids = [self.tokenizer.bos_token_id] + self.tokenizer(prefix)["input_ids"]
+
+         prefix_seq = data["schema_sequence"] + "\n" + data["content_sequence"] + "\n" + data["text"] + "\n"
+         print(prefix_seq)
+
+         input_ids = [self.tokenizer.bos_token_id] + self.tokenizer(prefix_seq , truncation = False)["input_ids"]
          if len(input_ids) > self.max_prefix_length:
-             input_ids = [self.tokenizer.bos_token_id] + input_ids[-(self.max_prefix_length - 1):]
+             print("the current input sequence exceeds the max_tokens, we will truncate it.")
+             input_ids = [self.tokenizer.bos_token_id] + input_ids[-(self.max_prefix_length-1):]
          attention_mask = [1] * len(input_ids)
-
+
          inputs = {
-             "input_ids": torch.tensor([input_ids], dtype=torch.int64).to(self.model.device),
-             "attention_mask": torch.tensor([attention_mask], dtype=torch.int64).to(self.model.device)
+             "input_ids": torch.tensor([input_ids], dtype = torch.int64).to(self.model.device),
+             "attention_mask": torch.tensor([attention_mask], dtype = torch.int64).to(self.model.device)
          }
+         input_length = inputs["input_ids"].shape[1]

          with torch.no_grad():
-             outputs = self.model.generate(
+             generate_ids = self.model.generate(
                  **inputs,
-                 max_new_tokens=self.max_new_tokens,
-                 num_beams=4,
-                 num_return_sequences=4
+                 max_new_tokens = self.max_new_tokens,
+                 num_beams = 4,
+                 num_return_sequences = 4
              )

-         # Decode and choose executable SQL
-         decoded = self.tokenizer.batch_decode(
-             outputs[:, inputs['input_ids'].shape[1]:],
-             skip_special_tokens=True,
-             clean_up_tokenization_spaces=False
-         )
-         final_sql = None
-         for sql in decoded:
-             if check_sql_executability(sql, os.path.join("databases", db_id, f"{db_id}.sqlite")) is None:
-                 final_sql = sql.strip()
+         generated_sqls = self.tokenizer.batch_decode(generate_ids[:, input_length:], skip_special_tokens = True, clean_up_tokenization_spaces = False)
+         final_generated_sql = None
+         for generated_sql in generated_sqls:
+             execution_error = check_sql_executability(generated_sql, os.path.join("databases", db_id, db_id + ".sqlite"))
+             if execution_error is None:
+                 final_generated_sql = generated_sql
                  break
-         if not final_sql:
-             final_sql = decoded[0].strip() or "Sorry, I cannot generate a suitable SQL query."

-         return final_sql
+         if final_generated_sql is None:
+             if generated_sqls[0].strip() != "":
+                 final_generated_sql = generated_sqls[0].strip()
+             else:
+                 final_generated_sql = "Sorry, I can not generate a suitable SQL query for your question."
+
+         return final_generated_sql.replace("\n", " ")
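A brief usage sketch for the rewritten ChatBot above, assuming the databases/, data/tables.json, and db_contents_index/ layout that __init__ reads (the question and db_id below are illustrative, not part of the commit):

from text2sql import ChatBot

bot = ChatBot()  # loads seeklhy/codes-1b plus per-database schemas, DDLs, and Whoosh searchers
sql = bot.get_response("How many singers do we have?", "concert_singer")
print(sql)       # newline-free SQL string, or an apology message if no executable candidate was found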