import copy
import json
import os
import re
import sqlite3

import sqlparse
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from whoosh import index
from whoosh.fields import Schema, TEXT
from whoosh.index import create_in
from whoosh.qparser import QueryParser

from schema_item_filter import SchemaItemClassifierInference, filter_schema
from utils.db_utils import (
    check_sql_executability,
    get_db_schema,
    get_db_schema_sequence,
    get_matched_content_sequence,
    get_matched_contents,
)


def remove_similar_comments(names, comments):
    """Blank out comments that merely repeat the name they describe.

    A comment is considered redundant when it equals its name after
    underscores and spaces have been removed from both strings.

    Args:
        names: Table (or column) names.
        comments: Comments paired position-by-position with ``names``.

    Returns:
        A new list with one entry per (name, comment) pair: ``""`` where the
        comment is redundant, the original comment otherwise.
    """
    def _squash(text):
        return text.replace("_", "").replace(" ", "")

    return [
        "" if _squash(name) == _squash(comment) else comment
        for name, comment in zip(names, comments)
    ]


def load_db_comments(table_json_path):
    """Load table and column comments for every database in a tables JSON file.

    Each entry of the JSON list is read through its ``db_id``,
    ``table_names_original``, ``table_names``, ``column_names_original`` and
    ``column_names`` keys. All names and comments are lower-cased, and comments
    that only repeat their name are replaced by ``""``.

    Args:
        table_json_path: Path to the JSON file.

    Returns:
        A dict mapping ``db_id`` to ``{table_name: {"table_comment": str,
        "column_comments": {column_name: str}}}``.

    Raises:
        AssertionError: If the original and commented name lists of a database
            differ in length.
    """
    # Use a context manager so the file handle is released deterministically.
    with open(table_json_path, encoding="utf-8") as json_file:
        additional_db_info = json.load(json_file)

    db_comments = dict()
    for db_info in additional_db_info:
        column_names = [name.lower() for _, name in db_info["column_names_original"]]
        table_idx_of_each_column = [t_idx for t_idx, _ in db_info["column_names_original"]]
        column_comments = [comment.lower() for _, comment in db_info["column_names"]]
        assert len(column_names) == len(column_comments)
        column_comments = remove_similar_comments(column_names, column_comments)

        table_names = [name.lower() for name in db_info["table_names_original"]]
        table_comments = [comment.lower() for comment in db_info["table_names"]]
        assert len(table_names) == len(table_comments)
        table_comments = remove_similar_comments(table_names, table_comments)

        # Group the columns by owning table index in a single pass instead of
        # re-scanning every column for every table. Columns whose index matches
        # no table are simply never looked up below.
        columns_by_table = dict()
        for t_idx, column_name, column_comment in zip(
            table_idx_of_each_column, column_names, column_comments
        ):
            columns_by_table.setdefault(t_idx, dict())[column_name] = column_comment

        comment_dict = dict()
        for table_idx, (table_name, table_comment) in enumerate(zip(table_names, table_comments)):
            comment_dict[table_name] = {
                "table_comment": table_comment,
                "column_comments": columns_by_table.get(table_idx, dict()),
            }

        db_comments[db_info["db_id"]] = comment_dict

    return db_comments


def get_db_id2schema(db_path, tables_json):
    """Build the schema of every database found under ``db_path``.

    Every entry ``<db_id>`` of ``db_path`` is expected to hold a
    ``<db_id>/<db_id>.sqlite`` file, which is passed to ``get_db_schema``
    together with the comments loaded from ``tables_json``.

    Args:
        db_path: Directory containing one sub-directory per database.
        tables_json: Path of the JSON file handed to ``load_db_comments``.

    Returns:
        A dict mapping ``db_id`` to the value returned by ``get_db_schema``.
    """
    db_comments = load_db_comments(tables_json)
    db_id2schema = dict()

    for db_id in tqdm(os.listdir(db_path)):
        sqlite_path = os.path.join(db_path, db_id, db_id + ".sqlite")
        db_id2schema[db_id] = get_db_schema(sqlite_path, db_comments, db_id)

    return db_id2schema


def _format_table_ddl(table_ddl):
    """Normalise a single ``CREATE TABLE`` statement for display."""
    # str.replace returns a new string; the result has to be kept.
    table_ddl = table_ddl.replace("\t", " ")
    # Collapse every run of spaces into a single space.
    table_ddl = re.sub(r" {2,}", " ", table_ddl)

    # Strip SQL line comments.
    table_ddl = re.sub(r"--.*", "", table_ddl)
    table_ddl = sqlparse.format(
        table_ddl,
        keyword_case="upper",
        identifier_case="lower",
        reindent_aligned=True,
    )
    table_ddl = table_ddl.replace(", ", ",\n ")

    if table_ddl.endswith(";"):
        table_ddl = table_ddl[:-1]
    # Drop the last character (presumably the closing parenthesis) and put the
    # terminator on a line of its own.
    table_ddl = table_ddl[:-1] + "\n);"
    table_ddl = re.sub(r"(CREATE TABLE.*?)\(", r"\1(\n ", table_ddl)

    return table_ddl


def get_db_id2ddl(db_path):
    """Collect the formatted DDL of every database found under ``db_path``.

    For each entry ``<db_id>`` of ``db_path`` the file
    ``<db_id>/<db_id>.sqlite`` is opened and the ``sql`` column of
    ``sqlite_master`` is read for all rows of type ``table``.

    Args:
        db_path: Directory containing one sub-directory per database.

    Returns:
        A dict mapping ``db_id`` to the formatted ``CREATE TABLE`` statements
        of that database, separated by blank lines.
    """
    db_id2ddl = dict()

    for db_id in os.listdir(db_path):
        conn = sqlite3.connect(os.path.join(db_path, db_id, db_id + ".sqlite"))
        try:
            cursor = conn.cursor()
            cursor.execute("SELECT name, sql FROM sqlite_master WHERE type='table';")
            tables = cursor.fetchall()
        finally:
            # Always release the connection, even when the query fails.
            conn.close()

        db_id2ddl[db_id] = "\n\n".join(_format_table_ddl(table_sql) for _, table_sql in tables)

    return db_id2ddl


class ChatBot():
    """Text-to-SQL assistant answering questions about the local databases.

    Construction loads the language model and tokenizer, the schema item
    classifier, one Whoosh searcher per directory of ``db_contents_index``,
    and the schemas and DDL of everything under ``databases``.
    """

    def __init__(self) -> None:
        """Load the models, content indexes and database metadata.

        Raises:
            ValueError: If an entry of ``db_contents_index`` does not contain
                a Whoosh index.
        """
        os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        model_name = "seeklhy/codes-1b"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, device_map="auto", torch_dtype=torch.float16
        )
        self.max_length = 4096
        self.max_new_tokens = 256
        # Token budget left for the prompt once room for generation is reserved.
        self.max_prefix_length = self.max_length - self.max_new_tokens

        # Directly loading the model from Hugging Face
        self.sic = SchemaItemClassifierInference("Roxanne-WANG/LangSQL")

        self.db_id2content_searcher = dict()
        for db_id in os.listdir("db_contents_index"):
            index_dir = os.path.join("db_contents_index", db_id)

            if not index.exists_in(index_dir):
                raise ValueError(f"No Whoosh index found for '{db_id}' at '{index_dir}'")

            # Open the existing Whoosh index and keep a searcher around for querying.
            ix = index.open_dir(index_dir)
            self.db_id2content_searcher[db_id] = ix.searcher()

        self.db_ids = sorted(os.listdir("databases"))
        self.db_id2schema = get_db_id2schema("databases", "data/tables.json")
        self.db_id2ddl = get_db_id2ddl("databases")

    def get_response(self, question, db_id):
        """Generate a SQL query answering ``question`` on database ``db_id``.

        The prompt is the filtered schema sequence, the matched content
        sequence and the question, each followed by a newline. It is printed,
        tokenized and, when longer than ``self.max_prefix_length``, truncated
        from the left. Four beams are decoded, and the first candidate for
        which ``check_sql_executability`` returns ``None`` is selected.

        Args:
            question: Natural-language question.
            db_id: Key into ``self.db_id2schema`` and
                ``self.db_id2content_searcher``.

        Returns:
            The selected SQL with newlines replaced by spaces. When no
            candidate is executable, the stripped first candidate is returned,
            or an apology sentence if that candidate is empty.

        Raises:
            KeyError: If ``db_id`` is not a known database.
        """
        data = {
            "text": question,
            # filter_schema receives a private copy of the cached schema.
            "schema": copy.deepcopy(self.db_id2schema[db_id]),
            "matched_contents": get_matched_contents(question, self.db_id2content_searcher[db_id]),
        }
        # NOTE(review): 6 and 10 are presumably the table / column top-k
        # limits - confirm against filter_schema.
        data = filter_schema(data, self.sic, 6, 10)
        data["schema_sequence"] = get_db_schema_sequence(data["schema"])
        data["content_sequence"] = get_matched_content_sequence(data["matched_contents"])

        prefix_seq = data["schema_sequence"] + "\n" + data["content_sequence"] + "\n" + data["text"] + "\n"
        print(prefix_seq)

        bos_token_id = self.tokenizer.bos_token_id
        input_ids = [bos_token_id] + self.tokenizer(prefix_seq, truncation=False)["input_ids"]
        if len(input_ids) > self.max_prefix_length:
            print("the current input sequence exceeds the max_tokens, we will truncate it.")
            # Keep the tail of the prompt (the question sits at the end) and
            # re-attach the BOS token in front of it.
            input_ids = [bos_token_id] + input_ids[-(self.max_prefix_length - 1):]
        attention_mask = [1] * len(input_ids)

        device = self.model.device
        inputs = {
            "input_ids": torch.tensor([input_ids], dtype=torch.int64).to(device),
            "attention_mask": torch.tensor([attention_mask], dtype=torch.int64).to(device),
        }
        input_length = inputs["input_ids"].shape[1]

        with torch.no_grad():
            generate_ids = self.model.generate(
                **inputs,
                max_new_tokens=self.max_new_tokens,
                num_beams=4,
                num_return_sequences=4,
            )

        # Decode only the newly generated tokens, not the echoed prompt.
        generated_sqls = self.tokenizer.batch_decode(
            generate_ids[:, input_length:],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )

        sqlite_path = os.path.join("databases", db_id, db_id + ".sqlite")
        final_generated_sql = None
        for generated_sql in generated_sqls:
            if check_sql_executability(generated_sql, sqlite_path) is None:
                final_generated_sql = generated_sql
                break

        if final_generated_sql is None:
            # No candidate executes: fall back to the first one, if any.
            first_candidate = generated_sqls[0].strip()
            if first_candidate != "":
                final_generated_sql = first_candidate
            else:
                final_generated_sql = "Sorry, I can not generate a suitable SQL query for your question."

        return final_generated_sql.replace("\n", " ")