# Attribution: Original code by Ruoxin Wang
# Repository: <your-repo-url>
"""
Module: refactored_chatbot

This module provides utilities for loading database schemas, extracting DDL,
indexing content, and a ChatBot class to generate SQL queries from natural language.
"""
import os
import json
import re
import sqlite3
import copy

import sqlparse
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from whoosh import index

from utils.db_utils import (
    get_db_schema,
    check_sql_executability,
    get_matched_contents,
    get_db_schema_sequence,
    get_matched_content_sequence
)
from schema_item_filter import SchemaItemClassifierInference, filter_schema

class DatabaseUtils:
    """
    Utilities for loading database comments, schemas, and DDL statements.
    """

    @staticmethod
    def _remove_similar_comments(names, comments):
        """
        Remove comments identical to table/column names (ignoring underscores and spaces).
        """
        filtered = []
        for name, comment in zip(names, comments):
            normalized_name = name.replace("_", "").replace(" ", "").lower()
            normalized_comment = comment.replace("_", "").replace(" ", "").lower()
            filtered.append("" if normalized_name == normalized_comment else comment)
        return filtered
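
    # Illustrative examples of the filtering above (hypothetical values):
    #   DatabaseUtils._remove_similar_comments(["user_id"], ["user id"])      -> [""]
    #   DatabaseUtils._remove_similar_comments(["uid"], ["user identifier"])  -> ["user identifier"]
    # A comment that only restates the name is blanked out; genuinely descriptive
    # comments are kept unchanged.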

    @staticmethod
    def load_db_comments(table_json_path):
        """
        Load additional comments for tables and columns from a JSON file.

        Args:
            table_json_path (str): Path to JSON file containing table and column comments.

        Returns:
            dict: Mapping from database ID to comments structure.
        """
        additional_info = json.load(open(table_json_path))
        db_comments = {}
        for db_info in additional_info:
            db_id = db_info["db_id"]
            comment_dict = {}

            # Process column comments
            original_cols = db_info["column_names_original"]
            col_names = [col.lower() for _, col in original_cols]
            col_comments = [c.lower() for _, c in db_info["column_names"]]
            col_comments = DatabaseUtils._remove_similar_comments(col_names, col_comments)
            col_table_idxs = [t_idx for t_idx, _ in original_cols]

            # Process table comments
            original_tables = db_info["table_names_original"]
            tbl_names = [tbl.lower() for tbl in original_tables]
            tbl_comments = [c.lower() for c in db_info["table_names"]]
            tbl_comments = DatabaseUtils._remove_similar_comments(tbl_names, tbl_comments)

            for idx, name in enumerate(tbl_names):
                comment_dict[name] = {
                    "table_comment": tbl_comments[idx],
                    "column_comments": {}
                }
                # Associate columns belonging to this table
                for t_idx, col_name, col_comment in zip(col_table_idxs, col_names, col_comments):
                    if t_idx == idx:
                        comment_dict[name]["column_comments"][col_name] = col_comment

            db_comments[db_id] = comment_dict
        return db_comments
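
    # For reference, each entry in the tables JSON is expected to follow the
    # Spider-style layout read above (illustrative sketch; the field values here
    # are hypothetical):
    #   {
    #     "db_id": "concert_singer",
    #     "table_names_original": ["stadium", "singer"],
    #     "table_names": ["stadium", "singer"],
    #     "column_names_original": [[0, "stadium_id"], [1, "name"]],
    #     "column_names": [[0, "stadium id"], [1, "singer name"]]
    #   }
    # Column entries are [table_index, column_name] pairs, which is why the code
    # pairs col_table_idxs with col_names when associating columns to tables.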

    @staticmethod
    def get_db_schemas(db_path, tables_json):
        """
        Build a mapping from database ID to its schema representation.

        Args:
            db_path (str): Directory containing database subdirectories.
            tables_json (str): Path to JSON with table comments.

        Returns:
            dict: Mapping from db_id to schema object.
        """
        comments = DatabaseUtils.load_db_comments(tables_json)
        schemas = {}
        for db_id in tqdm(os.listdir(db_path), desc="Loading schemas"):
            sqlite_path = os.path.join(db_path, db_id, f"{db_id}.sqlite")
            schemas[db_id] = get_db_schema(sqlite_path, comments, db_id)
        return schemas
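
    # Expected on-disk layout (illustrative; directory names are hypothetical):
    # each database lives in its own sub-directory named after its db_id, with a
    # matching .sqlite file inside, since the path is built as
    # os.path.join(db_path, db_id, f"{db_id}.sqlite"):
    #   databases/
    #     concert_singer/concert_singer.sqlite
    #     world_1/world_1.sqlite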

    @staticmethod
    def get_db_ddls(db_path):
        """
        Extract formatted DDL statements for all tables in each database.

        Args:
            db_path (str): Directory containing database subdirectories.

        Returns:
            dict: Mapping from db_id to its DDL string.
        """
        ddls = {}
        for db_id in os.listdir(db_path):
            conn = sqlite3.connect(os.path.join(db_path, db_id, f"{db_id}.sqlite"))
            cursor = conn.cursor()
            cursor.execute("SELECT name, sql FROM sqlite_master WHERE type='table';")
            ddl_statements = []
            for name, raw_sql in cursor.fetchall():
                sql = raw_sql or ""
                # Strip inline comments and collapse whitespace
                sql = re.sub(r"--.*", "", sql).replace("\t", " ")
                sql = re.sub(r" +", " ", sql)
                formatted = sqlparse.format(
                    sql,
                    keyword_case="upper",
                    identifier_case="lower",
                    reindent_aligned=True
                )
                # Adjust spacing for readability
                formatted = formatted.replace(", ", ",\n ")
                if formatted.rstrip().endswith(";"):
                    formatted = formatted.rstrip()[:-1] + "\n);"
                formatted = re.sub(r"(CREATE TABLE.*?)\(", r"\1(\n ", formatted)
                ddl_statements.append(formatted)
            conn.close()
            ddls[db_id] = "\n\n".join(ddl_statements)
        return ddls
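

# Usage sketch for the utilities above (assumes the repository layout that
# ChatBot.__init__ relies on below; "some_db_id" is a placeholder for a real
# database directory name):
#   schemas = DatabaseUtils.get_db_schemas("databases", "data/tables.json")
#   ddls = DatabaseUtils.get_db_ddls("databases")
#   print(ddls["some_db_id"])
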
class ChatBot:
    """
    ChatBot for generating and executing SQL queries using a causal language model.
    """

    def __init__(self, model_name: str = "seeklhy/codes-1b", device: str = "cuda:0") -> None:
        """
        Initialize the ChatBot with model and tokenizer.

        Args:
            model_name (str): HuggingFace model identifier.
            device (str): CUDA device string (e.g. 'cuda:0') or 'cpu'.
        """
        # CUDA_VISIBLE_DEVICES expects a bare device index such as "0", not "cuda:0"
        if device.startswith("cuda"):
            os.environ["CUDA_VISIBLE_DEVICES"] = device.split(":")[-1] if ":" in device else "0"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            torch_dtype=torch.float16
        )
        self.max_length = 4096
        self.max_new_tokens = 256
        self.max_prefix_length = self.max_length - self.max_new_tokens

        # Schema item classifier
        self.schema_classifier = SchemaItemClassifierInference("Roxanne-WANG/LangSQL")

        # Initialize content searchers
        self.content_searchers = {}
        index_dir = "db_contents_index"
        for db_id in os.listdir(index_dir):
            path = os.path.join(index_dir, db_id)
            if index.exists_in(path):
                self.content_searchers[db_id] = index.open_dir(path).searcher()
            else:
                raise FileNotFoundError(f"Whoosh index not found for '{db_id}' at '{path}'")

        # Load schemas and DDLs
        self.db_ids = sorted(os.listdir("databases"))
        self.schemas = DatabaseUtils.get_db_schemas("databases", "data/tables.json")
        self.ddls = DatabaseUtils.get_db_ddls("databases")

    def get_response(self, question: str, db_id: str) -> str:
        """
        Generate an executable SQL query for a natural language question.

        Args:
            question (str): User question in natural language.
            db_id (str): Identifier of the target database.

        Returns:
            str: Executable SQL query or an error message.
        """
        # Prepare data
        schema = copy.deepcopy(self.schemas[db_id])
        contents = get_matched_contents(question, self.content_searchers[db_id])
        data = {
            "text": question,
            "schema": schema,
            "matched_contents": contents
        }
        data = filter_schema(data, self.schema_classifier, top_k=6, top_m=10)
        data["schema_sequence"] = get_db_schema_sequence(data["schema"])
        data["content_sequence"] = get_matched_content_sequence(data["matched_contents"])

        prefix = (
            f"{data['schema_sequence']}\n"
            f"{data['content_sequence']}\n"
            f"{question}\n"
        )

        # Tokenize and ensure length limits
        input_ids = [self.tokenizer.bos_token_id] + self.tokenizer(prefix)["input_ids"]
        if len(input_ids) > self.max_prefix_length:
            input_ids = [self.tokenizer.bos_token_id] + input_ids[-(self.max_prefix_length - 1):]
        attention_mask = [1] * len(input_ids)
        inputs = {
            "input_ids": torch.tensor([input_ids], dtype=torch.int64).to(self.model.device),
            "attention_mask": torch.tensor([attention_mask], dtype=torch.int64).to(self.model.device)
        }

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=self.max_new_tokens,
                num_beams=4,
                num_return_sequences=4
            )

        # Decode and choose executable SQL
        decoded = self.tokenizer.batch_decode(
            outputs[:, inputs["input_ids"].shape[1]:],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )
        final_sql = None
        for sql in decoded:
            if check_sql_executability(sql, os.path.join("databases", db_id, f"{db_id}.sqlite")) is None:
                final_sql = sql.strip()
                break
        if not final_sql:
            final_sql = decoded[0].strip() or "Sorry, I cannot generate a suitable SQL query."
        return final_sql
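

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): assumes the working directory
    # contains the "databases/" folder, "data/tables.json", and the prebuilt
    # Whoosh indexes under "db_contents_index/" that ChatBot.__init__ expects.
    bot = ChatBot()
    example_db_id = bot.db_ids[0]  # hypothetical: pick the first available database
    example_question = "How many rows does the largest table contain?"  # hypothetical question
    print(bot.get_response(example_question, example_db_id))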