import os
import json
import torch
import copy
import re
import sqlparse
import sqlite3

from tqdm import tqdm
from utils.db_utils import get_db_schema
from transformers import AutoModelForCausalLM, AutoTokenizer
from whoosh import index
from whoosh.index import create_in
from whoosh.fields import Schema, TEXT
from whoosh.qparser import QueryParser
from utils.db_utils import check_sql_executability, get_matched_contents, get_db_schema_sequence, get_matched_content_sequence
from schema_item_filter import SchemaItemClassifierInference, filter_schema

def remove_similar_comments(names, comments):
    '''
    Remove table (or column) comments that are nearly identical to their names
    '''
    new_comments = []
    for name, comment in zip(names, comments):
        if name.replace("_", "").replace(" ", "") == comment.replace("_", "").replace(" ", ""):
            new_comments.append("")
        else:
            new_comments.append(comment)

    return new_comments

def load_db_comments(table_json_path):
    additional_db_info = json.load(open(table_json_path))
    db_comments = dict()

    for db_info in additional_db_info:
        comment_dict = dict()

        column_names = [column_name.lower() for _, column_name in db_info["column_names_original"]]
        table_idx_of_each_column = [t_idx for t_idx, _ in db_info["column_names_original"]]
        column_comments = [column_comment.lower() for _, column_comment in db_info["column_names"]]

        assert len(column_names) == len(column_comments)
        column_comments = remove_similar_comments(column_names, column_comments)

        table_names = [table_name.lower() for table_name in db_info["table_names_original"]]
        table_comments = [table_comment.lower() for table_comment in db_info["table_names"]]

        assert len(table_names) == len(table_comments)
        table_comments = remove_similar_comments(table_names, table_comments)

        # build {table_name: {"table_comment": ..., "column_comments": {column_name: ...}}}
        for table_idx, (table_name, table_comment) in enumerate(zip(table_names, table_comments)):
            comment_dict[table_name] = {
                "table_comment": table_comment,
                "column_comments": dict()
            }
            for t_idx, column_name, column_comment in zip(table_idx_of_each_column, column_names, column_comments):
                if t_idx == table_idx:
                    comment_dict[table_name]["column_comments"][column_name] = column_comment

        db_comments[db_info["db_id"]] = comment_dict

    return db_comments

def get_db_id2schema(db_path, tables_json):
    db_comments = load_db_comments(tables_json)
    db_id2schema = dict()

    for db_id in tqdm(os.listdir(db_path)):
        db_id2schema[db_id] = get_db_schema(os.path.join(db_path, db_id, db_id + ".sqlite"), db_comments, db_id)

    return db_id2schema

def get_db_id2ddl(db_path):
    db_ids = os.listdir(db_path)
    db_id2ddl = dict()

    for db_id in db_ids:
        conn = sqlite3.connect(os.path.join(db_path, db_id, db_id + ".sqlite"))
        cursor = conn.cursor()
        cursor.execute("SELECT name, sql FROM sqlite_master WHERE type='table';")
        tables = cursor.fetchall()
        ddl = []

        for table in tables:
            table_name = table[0]
            table_ddl = table[1]
            # assign the result: str.replace is not in-place
            table_ddl = table_ddl.replace("\t", " ")
            # collapse runs of whitespace into single spaces
            while "  " in table_ddl:
                table_ddl = table_ddl.replace("  ", " ")

            # strip SQL comments, then normalize keyword/identifier casing and layout
            table_ddl = re.sub(r'--.*', '', table_ddl)
            table_ddl = sqlparse.format(table_ddl, keyword_case="upper", identifier_case="lower", reindent_aligned=True)
            table_ddl = table_ddl.replace(", ", ",\n    ")

            if table_ddl.endswith(";"):
                table_ddl = table_ddl[:-1]
            table_ddl = table_ddl[:-1] + "\n);"
            table_ddl = re.sub(r"(CREATE TABLE.*?)\(", r"\1(\n    ", table_ddl)

            ddl.append(table_ddl)
        conn.close()
        db_id2ddl[db_id] = "\n\n".join(ddl)

    return db_id2ddl
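# NOTE: a minimal sketch of how the "db_contents_index" directories opened in
# ChatBot.__init__ might be built with Whoosh (it uses the create_in/Schema/TEXT
# imports above, which are otherwise unused here). The helper name
# build_content_index and the single "content" field are assumptions, not part
# of the original code; adapt the schema to whatever get_matched_contents expects.
def build_content_index(db_path, index_root="db_contents_index"):
    for db_id in os.listdir(db_path):
        index_dir = os.path.join(index_root, db_id)
        os.makedirs(index_dir, exist_ok=True)
        # one document per non-empty text cell in the database
        ix = create_in(index_dir, Schema(content=TEXT(stored=True)))
        writer = ix.writer()
        conn = sqlite3.connect(os.path.join(db_path, db_id, db_id + ".sqlite"))
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        for (table_name,) in cursor.fetchall():
            for row in cursor.execute(f'SELECT * FROM "{table_name}"').fetchall():
                for value in row:
                    if isinstance(value, str) and value.strip() != "":
                        writer.add_document(content=value)
        writer.commit()
        conn.close()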
class ChatBot():
    def __init__(self) -> None:
        os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        model_name = "seeklhy/codes-1b"

        # Load tokenizer and model
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
        # Remember which device the model was placed on (CUDA or CPU)
        self.device = self.model.device

        # Generation parameters
        self.max_length = 4096
        self.max_new_tokens = 256
        self.max_prefix_length = self.max_length - self.max_new_tokens

        # Load the schema item classifier
        self.sic = SchemaItemClassifierInference("Roxanne-WANG/LangSQL")

        # Open a Whoosh content index for each database
        self.db_id2content_searcher = dict()
        for db_id in os.listdir("db_contents_index"):
            index_dir = os.path.join("db_contents_index", db_id)
            if index.exists_in(index_dir):
                ix = index.open_dir(index_dir)
                self.db_id2content_searcher[db_id] = ix

        # Pre-compute per-database schemas (required by get_response) and DDL.
        # NOTE: the tables.json path is assumed here; point it at your
        # Spider-style tables.json. "databases" matches the path used below.
        self.db_id2schema = get_db_id2schema("databases", "data/tables.json")
        self.db_id2ddl = get_db_id2ddl("databases")

    def get_response(self, question, db_id):
        # Prepare the data for schema filtering
        data = {
            "text": question,
            "schema": copy.deepcopy(self.db_id2schema[db_id]),
            "matched_contents": get_matched_contents(question, self.db_id2content_searcher[db_id])
        }
        # Keep only the schema items the classifier predicts as relevant
        # (top 6 tables, top 10 columns per table)
        data = filter_schema(data, self.sic, 6, 10)
        data["schema_sequence"] = get_db_schema_sequence(data["schema"])
        data["content_sequence"] = get_matched_content_sequence(data["matched_contents"])

        # Build the input sequence for the model
        prefix_seq = data["schema_sequence"] + "\n" + data["content_sequence"] + "\n" + data["text"] + "\n"
        input_ids = [self.tokenizer.bos_token_id] + self.tokenizer(prefix_seq, truncation=False)["input_ids"]
        if len(input_ids) > self.max_prefix_length:
            # Truncate from the left, keeping the BOS token and the most recent tokens
            input_ids = [self.tokenizer.bos_token_id] + input_ids[-(self.max_prefix_length - 1):]
        attention_mask = [1] * len(input_ids)

        # Move input tensors to the same device as the model
        inputs = {
            "input_ids": torch.tensor([input_ids], dtype=torch.int64).to(self.device),
            "attention_mask": torch.tensor([attention_mask], dtype=torch.int64).to(self.device)
        }

        # Generate candidate SQL queries with beam search
        with torch.no_grad():
            generate_ids = self.model.generate(
                **inputs,
                max_new_tokens=self.max_new_tokens,
                num_beams=4,
                num_return_sequences=4
            )

        # Decode the generated SQL queries, skipping the prefix tokens
        generated_sqls = self.tokenizer.batch_decode(generate_ids[:, len(input_ids):], skip_special_tokens=True, clean_up_tokenization_spaces=False)

        # Return the first candidate that executes without errors
        final_generated_sql = None
        for generated_sql in generated_sqls:
            execution_error = check_sql_executability(generated_sql, os.path.join("databases", db_id, db_id + ".sqlite"))
            if execution_error is None:
                final_generated_sql = generated_sql
                break

        # Fall back to the top beam (or an apology) if no candidate is executable
        if final_generated_sql is None:
            if generated_sqls[0].strip() != "":
                final_generated_sql = generated_sqls[0].strip()
            else:
                final_generated_sql = "Sorry, I cannot generate a suitable SQL query for your question."

        return final_generated_sql.replace("\n", " ")
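# A minimal usage sketch. It assumes the "databases" directory, tables.json, and
# the Whoosh indexes described above already exist; the question and db_id below
# are illustrative examples, not values from the original code.
if __name__ == "__main__":
    chatbot = ChatBot()
    sql = chatbot.get_response("How many singers do we have?", "concert_singer")
    print(sql)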