import os
import json
import torch
import copy
import re
import sqlparse
import sqlite3

from tqdm import tqdm
from utils.db_utils import get_db_schema
from transformers import AutoModelForCausalLM, AutoTokenizer
from whoosh import index
from whoosh.index import create_in
from whoosh.fields import Schema, TEXT
from whoosh.qparser import QueryParser
from utils.db_utils import check_sql_executability, get_matched_contents, get_db_schema_sequence, get_matched_content_sequence
from schema_item_filter import SchemaItemClassifierInference, filter_schema

def remove_similar_comments(names, comments):
    '''
    Remove table (or column) comments that have a high degree of similarity with their names
    '''
    new_comments = []
    for name, comment in zip(names, comments):
        if name.replace("_", "").replace(" ", "") == comment.replace("_", "").replace(" ", ""):
            new_comments.append("")
        else:
            new_comments.append(comment)

    return new_comments

def load_db_comments(table_json_path):
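    '''
    Load table and column comments from the tables.json file and return them keyed by database id.
    '''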
    additional_db_info = json.load(open(table_json_path))

    db_comments = dict()
    for db_info in additional_db_info:
        comment_dict = dict()

        column_names = [column_name.lower() for _, column_name in db_info["column_names_original"]]
        table_idx_of_each_column = [t_idx for t_idx, _ in db_info["column_names_original"]]
        column_comments = [column_comment.lower() for _, column_comment in db_info["column_names"]]

        assert len(column_names) == len(column_comments)
        column_comments = remove_similar_comments(column_names, column_comments)

        table_names = [table_name.lower() for table_name in db_info["table_names_original"]]
        table_comments = [table_comment.lower() for table_comment in db_info["table_names"]]

        assert len(table_names) == len(table_comments)
        table_comments = remove_similar_comments(table_names, table_comments)

        for table_idx, (table_name, table_comment) in enumerate(zip(table_names, table_comments)):
            comment_dict[table_name] = {
                "table_comment": table_comment,
                "column_comments": dict()
            }
            for t_idx, column_name, column_comment in zip(table_idx_of_each_column, column_names, column_comments):
                if t_idx == table_idx:
                    comment_dict[table_name]["column_comments"][column_name] = column_comment

        db_comments[db_info["db_id"]] = comment_dict

    return db_comments

def get_db_id2schema(db_path, tables_json):
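    '''
    Build a mapping from database id to its schema (with comments) for every database under db_path.
    '''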
    db_comments = load_db_comments(tables_json)
    db_id2schema = dict()

    for db_id in tqdm(os.listdir(db_path)):
        db_id2schema[db_id] = get_db_schema(os.path.join(db_path, db_id, db_id + ".sqlite"), db_comments, db_id)

    return db_id2schema

def get_db_id2ddl(db_path):
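    '''
    Build a mapping from database id to a cleaned-up CREATE TABLE DDL string for all of its tables.
    '''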
    db_ids = os.listdir(db_path)
    db_id2ddl = dict()

    for db_id in db_ids:
        conn = sqlite3.connect(os.path.join(db_path, db_id, db_id + ".sqlite"))
        cursor = conn.cursor()
        cursor.execute("SELECT name, sql FROM sqlite_master WHERE type='table';")
        tables = cursor.fetchall()
        ddl = []

        for table in tables:
            table_name = table[0]
            table_ddl = table[1]
table_ddl.replace("\t", " ") | |
while " " in table_ddl: | |
table_ddl = table_ddl.replace(" ", " ") | |
table_ddl = re.sub(r'--.*', '', table_ddl) | |
table_ddl = sqlparse.format(table_ddl, keyword_case = "upper", identifier_case = "lower", reindent_aligned = True) | |
table_ddl = table_ddl.replace(", ", ",\n ") | |
if table_ddl.endswith(";"): | |
table_ddl = table_ddl[:-1] | |
table_ddl = table_ddl[:-1] + "\n);" | |
table_ddl = re.sub(r"(CREATE TABLE.*?)\(", r"\1(\n ", table_ddl) | |
            ddl.append(table_ddl)

        db_id2ddl[db_id] = "\n\n".join(ddl)

    return db_id2ddl

class ChatBot():
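    '''
    Text-to-SQL chatbot: retrieves question-relevant database contents from a Whoosh index,
    filters the schema with a schema item classifier, and generates SQL with a causal LM (seeklhy/codes-1b).
    '''
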
    def __init__(self) -> None:
        os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        model_name = "seeklhy/codes-1b"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name, device_map = "auto", torch_dtype = torch.float16)
        self.max_length = 4096
        self.max_new_tokens = 256
        self.max_prefix_length = self.max_length - self.max_new_tokens

        # directly load the schema item classifier from Hugging Face
        self.sic = SchemaItemClassifierInference("Roxanne-WANG/LangSQL")

        self.db_id2content_searcher = dict()
        for db_id in os.listdir("db_contents_index"):
            index_dir = os.path.join("db_contents_index", db_id)
            # open the existing Whoosh index directory
            if index.exists_in(index_dir):
                ix = index.open_dir(index_dir)
                # keep a searcher around for querying
                self.db_id2content_searcher[db_id] = ix.searcher()
            else:
                raise ValueError(f"No Whoosh index found for '{db_id}' at '{index_dir}'")

        self.db_ids = sorted(os.listdir("databases"))
        self.db_id2schema = get_db_id2schema("databases", "data/tables.json")
        self.db_id2ddl = get_db_id2ddl("databases")

    def get_response(self, question, db_id):
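        '''
        Generate an executable SQL query for a natural language question against the database identified by db_id.
        '''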
        data = {
            "text": question,
            "schema": copy.deepcopy(self.db_id2schema[db_id]),
            "matched_contents": get_matched_contents(question, self.db_id2content_searcher[db_id])
        }
        data = filter_schema(data, self.sic, 6, 10)
        data["schema_sequence"] = get_db_schema_sequence(data["schema"])
        data["content_sequence"] = get_matched_content_sequence(data["matched_contents"])

        prefix_seq = data["schema_sequence"] + "\n" + data["content_sequence"] + "\n" + data["text"] + "\n"
        print(prefix_seq)
        input_ids = [self.tokenizer.bos_token_id] + self.tokenizer(prefix_seq, truncation = False)["input_ids"]
        if len(input_ids) > self.max_prefix_length:
            print("The input sequence exceeds max_prefix_length; truncating it from the left.")
            input_ids = [self.tokenizer.bos_token_id] + input_ids[-(self.max_prefix_length-1):]
        attention_mask = [1] * len(input_ids)
        inputs = {
            "input_ids": torch.tensor([input_ids], dtype = torch.int64).to(self.model.device),
            "attention_mask": torch.tensor([attention_mask], dtype = torch.int64).to(self.model.device)
        }
        input_length = inputs["input_ids"].shape[1]

        with torch.no_grad():
            generate_ids = self.model.generate(
                **inputs,
                max_new_tokens = self.max_new_tokens,
                num_beams = 4,
                num_return_sequences = 4
            )

        generated_sqls = self.tokenizer.batch_decode(generate_ids[:, input_length:], skip_special_tokens = True, clean_up_tokenization_spaces = False)

        final_generated_sql = None
        for generated_sql in generated_sqls:
            execution_error = check_sql_executability(generated_sql, os.path.join("databases", db_id, db_id + ".sqlite"))
            if execution_error is None:
                final_generated_sql = generated_sql
                break

        if final_generated_sql is None:
            if generated_sqls[0].strip() != "":
                final_generated_sql = generated_sqls[0].strip()
            else:
                final_generated_sql = "Sorry, I can not generate a suitable SQL query for your question."

        return final_generated_sql.replace("\n", " ")
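
# Minimal usage sketch, assuming the "databases/", "data/tables.json", and "db_contents_index/"
# assets described above are in place. The question and db_id below are hypothetical examples;
# substitute a database directory that actually exists under "databases/".
if __name__ == "__main__":
    chatbot = ChatBot()
    predicted_sql = chatbot.get_response("How many singers do we have?", "concert_singer")
    print(predicted_sql)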