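"""
Utility helpers for a text-to-SQL pipeline: SQLite access with execution
timeouts, coarse-to-fine matching of question n-grams against indexed
database contents (via Whoosh), and serialization of database schemas and
matched values into prompt-ready text sequences.
"""
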
import os
import json
import sqlite3

from func_timeout import func_set_timeout, FunctionTimedOut
from utils.bridge_content_encoder import get_matched_entries
from nltk.tokenize import word_tokenize
from nltk import ngrams

from whoosh.qparser import QueryParser

def add_a_record(question, db_id):
    conn = sqlite3.connect('data/history/history.sqlite')
    cursor = conn.cursor()
    cursor.execute("INSERT INTO record (question, db_id) VALUES (?, ?)", (question, db_id))

    conn.commit()
    conn.close()
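
# Example usage (illustrative question/db_id; assumes data/history/history.sqlite
# already contains a `record(question, db_id)` table):
#   add_a_record("How many singers do we have?", "concert_singer")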

def obtain_n_grams(sequence, max_n):
    # make sure the NLTK tokenizer models are available (no-op once downloaded)
    import nltk
    nltk.download('punkt', quiet=True)
    nltk.download('punkt_tab', quiet=True)

    tokens = word_tokenize(sequence)
    all_grams = []
    for n in range(1, max_n + 1):
        all_grams.extend([" ".join(gram) for gram in ngrams(tokens, n)])
    
    return all_grams
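
# For example, obtain_n_grams("list all singers", 2) returns the unigrams and
# bigrams: ["list", "all", "singers", "list all", "all singers"].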

# get a database cursor for a sqlite database path
def get_cursor_from_path(sqlite_path):
    try:
        if not os.path.exists(sqlite_path):
            # note: sqlite3.connect will create a new (empty) database file in this case
            print("Opening a new connection %s" % sqlite_path)
        connection = sqlite3.connect(sqlite_path, check_same_thread=False)
    except Exception as e:
        print(sqlite_path)
        raise e
    connection.text_factory = lambda b: b.decode(errors="ignore")
    cursor = connection.cursor()
    return cursor

# execute predicted sql with a time limitation
@func_set_timeout(15)
def execute_sql(cursor, sql):
    cursor.execute(sql)

    return cursor.fetchall()

# execute predicted sql with a long time limitation (for building the content index)
@func_set_timeout(2000)
def execute_sql_long_time_limitation(cursor, sql):
    cursor.execute(sql)

    return cursor.fetchall()

def check_sql_executability(generated_sql, db):
    if generated_sql.strip() == "":
        return "Error: empty string"
    try:
        cursor = get_cursor_from_path(db)
        execute_sql(cursor, generated_sql)
        execution_error = None
    except FunctionTimedOut as fto:
        print("SQL execution time out error: {}.".format(fto))
        execution_error = "SQL execution times out."
    except Exception as e:
        print("SQL execution runtime error: {}.".format(e))
        execution_error = str(e)
    
    return execution_error
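
# Example: check_sql_executability("SELECT * FROM no_such_table;", "db.sqlite")
# (an illustrative path) returns the sqlite3 error message as a string, while a
# query that runs cleanly returns None.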

def is_number(s):
    try:
        float(s)
        return True
    except ValueError:
        return False

def detect_special_char(name):
    for special_char in ['(', '-', ')', ' ', '/']:
        if special_char in name:
            return True

    return False

def add_quotation_mark(s):
    return "`" + s + "`"

def get_column_contents(column_name, table_name, cursor):
    select_column_sql = "SELECT DISTINCT `{}` FROM `{}` WHERE `{}` IS NOT NULL LIMIT 2;".format(column_name, table_name, column_name)
    results = execute_sql_long_time_limitation(cursor, select_column_sql)
    column_contents = [str(result[0]).strip() for result in results]
    # keep only non-empty values of at most 25 characters
    column_contents = [content for content in column_contents if len(content) != 0 and len(content) <= 25]

    return column_contents
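
# For example, for a column `name` in a table `singer` (illustrative names), this runs:
#   SELECT DISTINCT `name` FROM `singer` WHERE `name` IS NOT NULL LIMIT 2;
# and keeps at most two short, non-empty sample values.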

def get_matched_contents(question, searcher):
    # coarse-grained matching between the question n-grams and all indexed database contents
    grams = obtain_n_grams(question, 4)
    hits = []
    
    # build one QueryParser for the indexed field ("content" must match the
    # field name used when the index was built), then search each n-gram
    query_parser = QueryParser("content", schema=searcher.schema)
    for query in grams:
        parsed_query = query_parser.parse(query)  # convert the n-gram string into a Whoosh Query object
        hits.extend(searcher.search(parsed_query, limit=10))  # top-10 hits per n-gram
    
    coarse_matched_contents = dict()
    for hit in hits:
        matched_result = json.loads(hit.raw)
        # `tc_name` refers to column names like `table_name.column_name`, e.g., document_drafts.document_id
        tc_name = ".".join(matched_result["id"].split("-**-")[:2])
        if tc_name in coarse_matched_contents.keys():
            if matched_result["contents"] not in coarse_matched_contents[tc_name]:
                coarse_matched_contents[tc_name].append(matched_result["contents"])
        else:
            coarse_matched_contents[tc_name] = [matched_result["contents"]]
    
    fine_matched_contents = dict()
    for tc_name, contents in coarse_matched_contents.items():
        # Fine-grained matching between the question and coarse matched contents
        fm_contents = get_matched_entries(question, contents)
        
        if fm_contents is None:
            continue
        for _match_str, (field_value, _s_match_str, match_score, _s_match_score, _match_size) in fm_contents:
            if match_score < 0.9:
                continue
            if tc_name in fine_matched_contents.keys():
                if len(fine_matched_contents[tc_name]) < 25:
                    fine_matched_contents[tc_name].append(field_value.strip())
            else:
                fine_matched_contents[tc_name] = [field_value.strip()]

    return fine_matched_contents
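
# A minimal sketch of how a compatible searcher might be obtained, assuming the
# content index was built with Whoosh and stores a JSON payload exposed as
# `hit.raw` (the index path below is illustrative, not fixed by this module):
#   from whoosh.index import open_dir
#   ix = open_dir("data/content_index/concert_singer")
#   with ix.searcher() as searcher:
#       matched = get_matched_contents("How many singers do we have?", searcher)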


def get_db_schema_sequence(schema):
    schema_sequence = "database schema :\n"
    for table in schema["schema_items"]:
        table_name, table_comment = table["table_name"], table["table_comment"]
        if detect_special_char(table_name):
            table_name = add_quotation_mark(table_name)
        
        # if table_comment != "":
        #     table_name += " ( comment : " + table_comment + " )"

        column_info_list = []
        for column_name, column_type, column_comment, column_content, pk_indicator in \
            zip(table["column_names"], table["column_types"], table["column_comments"], table["column_contents"], table["pk_indicators"]):
            if detect_special_char(column_name):
                column_name = add_quotation_mark(column_name)
            additional_column_info = []
            # column type
            additional_column_info.append(column_type)
            # pk indicator
            if pk_indicator != 0:
                additional_column_info.append("primary key")
            # column comment
            if column_comment != "":
                additional_column_info.append("comment : " + column_comment)
            # representative column values
            if len(column_content) != 0:
                additional_column_info.append("values : " + " , ".join(column_content))
            
            column_info_list.append(table_name + "." + column_name + " ( " + " | ".join(additional_column_info) + " )")
        
        schema_sequence += "table "+ table_name + " , columns = [ " + " , ".join(column_info_list) + " ]\n"

    if len(schema["foreign_keys"]) != 0:
        schema_sequence += "foreign keys :\n"
        for foreign_key in schema["foreign_keys"]:
            for i in range(len(foreign_key)):
                if detect_special_char(foreign_key[i]):
                    foreign_key[i] = add_quotation_mark(foreign_key[i])
            schema_sequence += "{}.{} = {}.{}\n".format(foreign_key[0], foreign_key[1], foreign_key[2], foreign_key[3])
    else:
        schema_sequence += "foreign keys : None\n"

    return schema_sequence.strip()
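
# For an illustrative table `singer` with an integer primary key `singer_id`
# and a text column `name` (no comments, no sampled values), the output is:
#   database schema :
#   table singer , columns = [ singer.singer_id ( integer | primary key ) , singer.name ( text ) ]
#   foreign keys : None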

def get_matched_content_sequence(matched_contents):
    content_sequence = ""
    if len(matched_contents) != 0:
        content_sequence += "matched contents :\n"
        for tc_name, contents in matched_contents.items():
            table_name = tc_name.split(".")[0]
            column_name = tc_name.split(".")[1]
            if detect_special_char(table_name):
                table_name = add_quotation_mark(table_name)
            if detect_special_char(column_name):
                column_name = add_quotation_mark(column_name)
            
            content_sequence += table_name + "." + column_name + " ( " + " , ".join(contents) + " )\n"
    else:
        content_sequence = "matched contents : None"
    
    return content_sequence.strip()
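
# For example, get_matched_content_sequence({"singer.name": ["John", "Jane"]})
# (illustrative names/values) returns:
#   matched contents :
#   singer.name ( John , Jane )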

def get_db_schema(db_path, db_comments, db_id):
    if db_id in db_comments:
        db_comment = db_comments[db_id]
    else:
        db_comment = None

    cursor = get_cursor_from_path(db_path)
    
    # obtain table names
    results = execute_sql(cursor, "SELECT name FROM sqlite_master WHERE type='table';")
    table_names = [result[0].lower() for result in results]

    schema = dict()
    schema["schema_items"] = []
    foreign_keys = []
    # for each table
    for table_name in table_names:
        # skip SQLite system table: sqlite_sequence
        if table_name == "sqlite_sequence":
            continue
        # obtain column names in the current table
        results = execute_sql(cursor, "SELECT name, type, pk FROM PRAGMA_TABLE_INFO('{}')".format(table_name))
        column_names_in_one_table = [result[0].lower() for result in results]
        column_types_in_one_table = [result[1].lower() for result in results]
        pk_indicators_in_one_table = [result[2] for result in results]

        column_contents = []
        for column_name in column_names_in_one_table:
            column_contents.append(get_column_contents(column_name, table_name, cursor))
        
        # obtain foreign keys in the current table
        results = execute_sql(cursor, "SELECT * FROM pragma_foreign_key_list('{}');".format(table_name))
        for result in results:
            if None not in [result[3], result[2], result[4]]:
                foreign_keys.append([table_name.lower(), result[3].lower(), result[2].lower(), result[4].lower()])
        
        # obtain comments for each schema item
        if db_comment is not None:
            if table_name in db_comment: # record comments for tables and columns
                table_comment = db_comment[table_name]["table_comment"]
                column_comments = [db_comment[table_name]["column_comments"][column_name] \
                    if column_name in db_comment[table_name]["column_comments"] else "" \
                        for column_name in column_names_in_one_table]
            else: # current database has comment information, but the current table does not
                table_comment = ""
                column_comments = ["" for _ in column_names_in_one_table]
        else: # current database has no comment information
            table_comment = ""
            column_comments = ["" for _ in column_names_in_one_table]

        schema["schema_items"].append({
            "table_name": table_name,
            "table_comment": table_comment,
            "column_names": column_names_in_one_table,
            "column_types": column_types_in_one_table,
            "column_comments": column_comments,
            "column_contents": column_contents,
            "pk_indicators": pk_indicators_in_one_table
        })
    
    schema["foreign_keys"] = foreign_keys
    
    return schema
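
# A minimal end-to-end sketch (all paths and the comment file are illustrative
# assumptions, not fixed by this module):
#   with open("data/db_comments.json") as f:
#       db_comments = json.load(f)
#   schema = get_db_schema("data/databases/concert_singer/concert_singer.sqlite",
#                          db_comments, "concert_singer")
#   print(get_db_schema_sequence(schema))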