import os
import json
import torch
import copy
import re
import sqlparse
import sqlite3

from tqdm import tqdm
from utils.db_utils import get_db_schema
from transformers import AutoModelForCausalLM, AutoTokenizer
from whoosh import index
from utils.db_utils import check_sql_executability, get_matched_contents, get_db_schema_sequence, get_matched_content_sequence
from schema_item_filter import SchemaItemClassifierInference, filter_schema

def remove_similar_comments(names, comments):
    '''
    Blank out table (or column) comments that match their names once underscores
    and spaces are stripped, since such comments carry no extra information.
    '''
    new_comments = []
    for name, comment in zip(names, comments):
        if name.replace("_", "").replace(" ", "") == comment.replace("_", "").replace(" ", ""):
            new_comments.append("")
        else:
            new_comments.append(comment)

    return new_comments
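
# For example (hypothetical names and comments):
#   remove_similar_comments(["stu_id", "age"], ["stu id", "age of the student"])
#   returns ["", "age of the student"]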

def load_db_comments(table_json_path):
    with open(table_json_path) as f:
        additional_db_info = json.load(f)
    db_comments = dict()
    for db_info in additional_db_info:
        comment_dict = dict()

        column_names = [column_name.lower() for _, column_name in db_info["column_names_original"]]
        table_idx_of_each_column = [t_idx for t_idx, _ in db_info["column_names_original"]]
        column_comments = [column_comment.lower() for _, column_comment in db_info["column_names"]]
        
        assert len(column_names) == len(column_comments)
        column_comments = remove_similar_comments(column_names, column_comments)

        table_names = [table_name.lower() for table_name in db_info["table_names_original"]]
        table_comments = [table_comment.lower() for table_comment in db_info["table_names"]]
        
        assert len(table_names) == len(table_comments)
        table_comments = remove_similar_comments(table_names, table_comments)

        for table_idx, (table_name, table_comment) in enumerate(zip(table_names, table_comments)):
            comment_dict[table_name] = {
                "table_comment": table_comment,
                "column_comments": dict()
            }
            for t_idx, column_name, column_comment in zip(table_idx_of_each_column, column_names, column_comments):
                if t_idx == table_idx:
                    comment_dict[table_name]["column_comments"][column_name] = column_comment

        db_comments[db_info["db_id"]] = comment_dict
    
    return db_comments
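
# A minimal sketch of the tables.json entry shape assumed above (field names
# come from the code; the concrete values are hypothetical):
# {
#     "db_id": "school",
#     "table_names_original": ["Student"],
#     "table_names": ["student table"],
#     "column_names_original": [[0, "stu_id"]],
#     "column_names": [[0, "student id"]]
# }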

def get_db_id2schema(db_path, tables_json):
    db_comments = load_db_comments(tables_json)
    db_id2schema = dict()

    for db_id in tqdm(os.listdir(db_path)):
        db_id2schema[db_id] = get_db_schema(os.path.join(db_path, db_id, db_id + ".sqlite"), db_comments, db_id)
    
    return db_id2schema
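
# Assumed on-disk layout (the db_id "school" is hypothetical): each database
# lives in its own directory and the SQLite file carries the directory's name,
#   databases/school/school.sqlite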

def get_db_id2ddl(db_path):
    db_ids = os.listdir(db_path)
    db_id2ddl = dict()

    for db_id in db_ids:
        conn = sqlite3.connect(os.path.join(db_path, db_id, db_id + ".sqlite"))
        cursor = conn.cursor()
        cursor.execute("SELECT name, sql FROM sqlite_master WHERE type='table';")
        tables = cursor.fetchall()
        ddl = []
        
        for table in tables:
            table_name = table[0]
            table_ddl = table[1]
            # collapse tabs and runs of spaces before pretty-printing
            table_ddl = table_ddl.replace("\t", " ")
            while "  " in table_ddl:
                table_ddl = table_ddl.replace("  ", " ")

            # strip SQL line comments, then normalize keyword and identifier casing
            table_ddl = re.sub(r'--.*', '', table_ddl)
            table_ddl = sqlparse.format(table_ddl, keyword_case = "upper", identifier_case = "lower", reindent_aligned = True)
            table_ddl = table_ddl.replace(", ", ",\n    ")

            # put the closing parenthesis on its own line
            if table_ddl.endswith(";"):
                table_ddl = table_ddl[:-1]
            table_ddl = table_ddl[:-1] + "\n);"
            # break the line after "CREATE TABLE <name>("
            table_ddl = re.sub(r"(CREATE TABLE.*?)\(", r"\1(\n    ", table_ddl)

            ddl.append(table_ddl)
        conn.close()
        db_id2ddl[db_id] = "\n\n".join(ddl)
    
    return db_id2ddl
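
# Illustrative effect of the reformatting above (hypothetical table; exact
# spacing depends on sqlparse):
#   raw: CREATE TABLE Student (stu_id INTEGER, name TEXT);
#   out: CREATE TABLE student (
#            stu_id INTEGER,
#            name TEXT
#        );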

class ChatBot():
    def __init__(self) -> None:
        os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        model_name = "seeklhy/codes-1b"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name, device_map = "auto", torch_dtype = torch.float16)
        self.max_length = 4096
        self.max_new_tokens = 256
        # token budget left for the prompt after reserving room for generation
        self.max_prefix_length = self.max_length - self.max_new_tokens

        # load the schema item classifier directly from the Hugging Face Hub
        self.sic = SchemaItemClassifierInference("Roxanne-WANG/LangSQL")
        # one Whoosh searcher per database; each index is expected to have been
        # built beforehand under db_contents_index/<db_id>/
        self.db_id2content_searcher = dict()
        for db_id in os.listdir("db_contents_index"):
            index_dir = os.path.join("db_contents_index", db_id)

            # open the existing Whoosh index directory
            if index.exists_in(index_dir):
                ix = index.open_dir(index_dir)
                # keep a searcher around for querying
                self.db_id2content_searcher[db_id] = ix.searcher()
            else:
                raise ValueError(f"No Whoosh index found for '{db_id}' at '{index_dir}'")

        self.db_ids = sorted(os.listdir("databases"))
        self.db_id2schema = get_db_id2schema("databases", "data/tables.json")
        self.db_id2ddl = get_db_id2ddl("databases")

    def get_response(self, question, db_id):
        data = {
            "text": question,
            "schema": copy.deepcopy(self.db_id2schema[db_id]),
            "matched_contents": get_matched_contents(question, self.db_id2content_searcher[db_id])
        }
        # keep only the schema items the classifier ranks as relevant to the
        # question (the two integers are the top-k budgets for tables and columns)
        data = filter_schema(data, self.sic, 6, 10)
        data["schema_sequence"] = get_db_schema_sequence(data["schema"])
        data["content_sequence"] = get_matched_content_sequence(data["matched_contents"])
        
        prefix_seq = data["schema_sequence"] + "\n" + data["content_sequence"] + "\n" + data["text"] + "\n"
        print(prefix_seq)

        input_ids = [self.tokenizer.bos_token_id] + self.tokenizer(prefix_seq, truncation = False)["input_ids"]
        if len(input_ids) > self.max_prefix_length:
            print("The input sequence exceeds max_prefix_length and will be truncated.")
            # truncate from the left so the question, which sits at the end of
            # the prefix, is always preserved
            input_ids = [self.tokenizer.bos_token_id] + input_ids[-(self.max_prefix_length - 1):]
        attention_mask = [1] * len(input_ids)
        
        inputs = {
            "input_ids": torch.tensor([input_ids], dtype = torch.int64).to(self.model.device),
            "attention_mask": torch.tensor([attention_mask], dtype = torch.int64).to(self.model.device)
        }
        input_length = inputs["input_ids"].shape[1]

        # beam search; return all four beams so we can fall back if the
        # highest-scoring SQL fails to execute
        with torch.no_grad():
            generate_ids = self.model.generate(
                **inputs,
                max_new_tokens = self.max_new_tokens,
                num_beams = 4,
                num_return_sequences = 4
            )

        generated_sqls = self.tokenizer.batch_decode(generate_ids[:, input_length:], skip_special_tokens = True, clean_up_tokenization_spaces = False)
        # keep the first (highest-scoring) candidate that executes without error
        final_generated_sql = None
        for generated_sql in generated_sqls:
            execution_error = check_sql_executability(generated_sql, os.path.join("databases", db_id, db_id + ".sqlite"))
            if execution_error is None:
                final_generated_sql = generated_sql
                break

        if final_generated_sql is None:
            # no candidate executed; fall back to the top beam if it is non-empty
            if generated_sqls[0].strip() != "":
                final_generated_sql = generated_sqls[0].strip()
            else:
                final_generated_sql = "Sorry, I cannot generate a suitable SQL query for your question."

        return final_generated_sql.replace("\n", " ")
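
# Minimal usage sketch (hypothetical question and db_id; the db_id must match a
# directory under "databases/" with a pre-built index under "db_contents_index/"):
if __name__ == "__main__":
    chatbot = ChatBot()
    print(chatbot.get_response("How many singers do we have?", "concert_singer"))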