import os
import json
import torch
import copy
import re
import sqlparse
import sqlite3
from tqdm import tqdm
from utils.db_utils import get_db_schema
from transformers import AutoModelForCausalLM, AutoTokenizer
from whoosh import index
from whoosh.index import create_in
from whoosh.fields import Schema, TEXT
from whoosh.qparser import QueryParser
from utils.db_utils import check_sql_executability, get_matched_contents, get_db_schema_sequence, get_matched_content_sequence
from schema_item_filter import SchemaItemClassifierInference, filter_schema


def remove_similar_comments(names, comments):
    '''
    Remove table (or column) comments that have a high degree of similarity with their names
    '''
    new_comments = []
    for name, comment in zip(names, comments):
        if name.replace("_", "").replace(" ", "") == comment.replace("_", "").replace(" ", ""):
            new_comments.append("")
        else:
            new_comments.append(comment)

    return new_comments


def load_db_comments(table_json_path):
    additional_db_info = json.load(open(table_json_path))
    db_comments = dict()

    for db_info in additional_db_info:
        comment_dict = dict()

        column_names = [column_name.lower() for _, column_name in db_info["column_names_original"]]
        table_idx_of_each_column = [t_idx for t_idx, _ in db_info["column_names_original"]]
        column_comments = [column_comment.lower() for _, column_comment in db_info["column_names"]]

        assert len(column_names) == len(column_comments)
        column_comments = remove_similar_comments(column_names, column_comments)

        table_names = [table_name.lower() for table_name in db_info["table_names_original"]]
        table_comments = [table_comment.lower() for table_comment in db_info["table_names"]]

        assert len(table_names) == len(table_comments)
        table_comments = remove_similar_comments(table_names, table_comments)

        for table_idx, (table_name, table_comment) in enumerate(zip(table_names, table_comments)):
            comment_dict[table_name] = {
                "table_comment": table_comment,
                "column_comments": dict()
            }
            for t_idx, column_name, column_comment in zip(table_idx_of_each_column, column_names, column_comments):
                if t_idx == table_idx:
                    comment_dict[table_name]["column_comments"][column_name] = column_comment

        db_comments[db_info["db_id"]] = comment_dict

    return db_comments


def get_db_id2schema(db_path, tables_json):
    db_comments = load_db_comments(tables_json)
    db_id2schema = dict()

    for db_id in tqdm(os.listdir(db_path)):
        db_id2schema[db_id] = get_db_schema(os.path.join(db_path, db_id, db_id + ".sqlite"), db_comments, db_id)

    return db_id2schema


def get_db_id2ddl(db_path):
    db_ids = os.listdir(db_path)
    db_id2ddl = dict()

    for db_id in db_ids:
        conn = sqlite3.connect(os.path.join(db_path, db_id, db_id + ".sqlite"))
        cursor = conn.cursor()
        cursor.execute("SELECT name, sql FROM sqlite_master WHERE type='table';")
        tables = cursor.fetchall()
        ddl = []

        for table in tables:
            table_name = table[0]
            table_ddl = table[1]
            # normalize whitespace in the raw DDL
            table_ddl = table_ddl.replace("\t", " ")
            while "  " in table_ddl:
                table_ddl = table_ddl.replace("  ", " ")

            # strip inline SQL comments and pretty-print the statement
            table_ddl = re.sub(r'--.*', '', table_ddl)
            table_ddl = sqlparse.format(table_ddl, keyword_case = "upper", identifier_case = "lower", reindent_aligned = True)
            table_ddl = table_ddl.replace(", ", ",\n    ")

            if table_ddl.endswith(";"):
                table_ddl = table_ddl[:-1]
            table_ddl = table_ddl[:-1] + "\n);"
            table_ddl = re.sub(r"(CREATE TABLE.*?)\(", r"\1(\n    ", table_ddl)

            ddl.append(table_ddl)
        db_id2ddl[db_id] = "\n\n".join(ddl)

    return db_id2ddl


class ChatBot():
    def __init__(self) -> None:
        os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        model_name = "seeklhy/codes-1b"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name, device_map = "auto", torch_dtype = torch.float16)
        self.max_length = 4096
        self.max_new_tokens = 256
        self.max_prefix_length = self.max_length - self.max_new_tokens

        # load the schema item classifier directly from Hugging Face
        self.sic = SchemaItemClassifierInference("Roxanne-WANG/LangSQL")

        self.db_id2content_searcher = dict()
        for db_id in os.listdir("db_contents_index"):
            index_dir = os.path.join("db_contents_index", db_id)
            # open the existing Whoosh index directory for this database
            if index.exists_in(index_dir):
                ix = index.open_dir(index_dir)
                # keep a searcher around for querying
                self.db_id2content_searcher[db_id] = ix.searcher()
            else:
                raise ValueError(f"No Whoosh index found for '{db_id}' at '{index_dir}'")

        self.db_ids = sorted(os.listdir("databases"))
        self.db_id2schema = get_db_id2schema("databases", "data/tables.json")
        self.db_id2ddl = get_db_id2ddl("databases")

    def get_response(self, question, db_id):
        data = {
            "text": question,
            "schema": copy.deepcopy(self.db_id2schema[db_id]),
            "matched_contents": get_matched_contents(question, self.db_id2content_searcher[db_id])
        }
        data = filter_schema(data, self.sic, 6, 10)
        data["schema_sequence"] = get_db_schema_sequence(data["schema"])
        data["content_sequence"] = get_matched_content_sequence(data["matched_contents"])

        prefix_seq = data["schema_sequence"] + "\n" + data["content_sequence"] + "\n" + data["text"] + "\n"
        print(prefix_seq)

        input_ids = [self.tokenizer.bos_token_id] + self.tokenizer(prefix_seq, truncation = False)["input_ids"]
        if len(input_ids) > self.max_prefix_length:
            print("the current input sequence exceeds the maximum prefix length, so it will be truncated.")
            input_ids = [self.tokenizer.bos_token_id] + input_ids[-(self.max_prefix_length-1):]
        attention_mask = [1] * len(input_ids)

        inputs = {
            "input_ids": torch.tensor([input_ids], dtype = torch.int64).to(self.model.device),
            "attention_mask": torch.tensor([attention_mask], dtype = torch.int64).to(self.model.device)
        }
        input_length = inputs["input_ids"].shape[1]

        with torch.no_grad():
            generate_ids = self.model.generate(
                **inputs,
                max_new_tokens = self.max_new_tokens,
                num_beams = 4,
                num_return_sequences = 4
            )

        generated_sqls = self.tokenizer.batch_decode(generate_ids[:, input_length:], skip_special_tokens = True, clean_up_tokenization_spaces = False)

        # return the first candidate SQL that executes without errors
        final_generated_sql = None
        for generated_sql in generated_sqls:
            execution_error = check_sql_executability(generated_sql, os.path.join("databases", db_id, db_id + ".sqlite"))
            if execution_error is None:
                final_generated_sql = generated_sql
                break

        # fall back to the top beam (or an apology) if no candidate executes
        if final_generated_sql is None:
            if generated_sqls[0].strip() != "":
                final_generated_sql = generated_sqls[0].strip()
            else:
                final_generated_sql = "Sorry, I cannot generate a suitable SQL query for your question."

        return final_generated_sql.replace("\n", " ")
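

# Example usage sketch: assumes the "databases/", "db_contents_index/", and
# "data/tables.json" assets referenced above are present. The question string
# below is purely illustrative.
if __name__ == "__main__":
    chatbot = ChatBot()
    # pick any available database folder under "databases/"
    if chatbot.db_ids:
        example_db_id = chatbot.db_ids[0]
        sql = chatbot.get_response("How many records are in the database?", example_db_id)
        print(sql)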